diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 247d2cee4..19ab42e02 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -25,7 +25,7 @@ from .eswrap import CollectEvents, add_range_to_dsl from .main import root from .misc import PYTHON_LICENSE, add_client, GithubClient, Manifest, client_error, getdefault from .packaging import PACKAGE_FILE, Package, manage_versions, RELEASE_DIR -from .rule import TOMLRule, BaseQueryRuleData +from .rule import TOMLRule, QueryRuleData from .rule_loader import production_filter, RuleCollection from .utils import get_path, dict_hash @@ -389,7 +389,7 @@ def rule_event_search(ctx, rule, date_range, count, max_results, verbose, elasticsearch_client: Elasticsearch = None): """Search using a rule file against an Elasticsearch instance.""" - if isinstance(rule.contents.data, BaseQueryRuleData): + if isinstance(rule.contents.data, QueryRuleData): if verbose: click.echo(f'Searching rule: {rule.name}') diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index 6b336dce6..e24f66a7a 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -18,7 +18,7 @@ import click import yaml from .misc import JS_LICENSE, cached -from .rule import TOMLRule, BaseQueryRuleData, ThreatMapping +from .rule import TOMLRule, QueryRuleData, ThreatMapping from .rule import downgrade_contents_from_rule from .rule_loader import RuleCollection, DEFAULT_RULES_DIR from .schemas import CurrentSchema, definitions @@ -377,7 +377,7 @@ class Package(object): def get_summary_rule_info(r: TOMLRule): r = r.contents rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}' - if isinstance(rule.contents.data, BaseQueryRuleData): + if isinstance(rule.contents.data, QueryRuleData): rule_str += f'-{r.data.language}' rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}' @@ -387,7 +387,7 @@ class Package(object): # lookup the 
rule in the GitHub tag v{major.minor.patch} data = r.contents.data rules_dir_link = f'https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/' - rule_type = data.language if isinstance(data, BaseQueryRuleData) else data.type + rule_type = data.language if isinstance(data, QueryRuleData) else data.type return f'`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)' for rule in self.rules: diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 6c16d223e..193b9e8cd 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -5,15 +5,14 @@ """Rule object.""" import json from dataclasses import dataclass, field +from functools import cached_property from pathlib import Path from typing import Literal, Union, Optional, List, Any from uuid import uuid4 -import eql from marshmallow import validates_schema -import kql -from . import ecs, beats, utils +from . import utils from .mixins import MarshmallowDataclassMixin from .rule_formatter import toml_write, nested_normalize from .schemas import definitions @@ -169,67 +168,48 @@ class BaseRuleData(MarshmallowDataclassMixin): type: Literal[definitions.RuleType] threat: Optional[List[ThreatMapping]] + def validate_query(self, meta: RuleMeta) -> None: + pass + + +@dataclass +class QueryValidator: + query: str + + @property + def ast(self) -> Any: + raise NotImplementedError + + def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: + raise NotImplementedError() + @dataclass(frozen=True) -class BaseQueryRuleData(BaseRuleData): +class QueryRuleData(BaseRuleData): """Specific fields for query event types.""" type: Literal["query"] index: Optional[List[str]] query: str - language: str + language: definitions.FilterLanguages - @property - def parsed_query(self) -> Optional[object]: - return None + @cached_property + def validator(self) -> Optional[QueryValidator]: + if self.language == "kuery": + return KQLValidator(self.query) + elif 
self.language == "eql": + return EQLValidator(self.query) + def validate_query(self, meta: RuleMeta) -> None: + validator = self.validator + if validator is not None: + return validator.validate(self, meta) -@dataclass(frozen=True) -class KQLRuleData(BaseQueryRuleData): - """Specific fields for query event types.""" - language: Literal["kuery"] - - @property - def parsed_query(self) -> kql.ast.Expression: - return kql.parse(self.query) - - @property - def unique_fields(self): - return list(set(str(f) for f in self.parsed_query if isinstance(f, kql.ast.Field))) - - def to_eql(self) -> eql.ast.Expression: - return kql.to_eql(self.query) - - def validate_query(self, beats_version: str, ecs_versions: List[str]): - """Static method to validate the query, called from the parent which contains [metadata] information.""" - indexes = self.index or [] - parsed = self.parsed_query - - beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index] - beat_schema = beats.get_schema_from_kql(parsed, beat_types, version=beats_version) if beat_types else None - - if not ecs_versions: - kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema)) - else: - for version in ecs_versions: - schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema) - - try: - kql.parse(self.query, schema=schema) - except kql.KqlParseError as exc: - message = exc.error_msg - trailer = None - if "Unknown field" in message and beat_types: - trailer = "\nTry adding event.module or event.dataset to specify beats module" - - raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) from None - - -@dataclass(frozen=True) -class LuceneRuleData(BaseQueryRuleData): - """Specific fields for query event types.""" - language: Literal["lucene"] + @cached_property + def ast(self): + validator = self.validator + if validator is not None: + return validator.ast @dataclass(frozen=True) @@ 
-241,7 +221,7 @@ class MachineLearningRuleData(BaseRuleData): @dataclass(frozen=True) -class ThresholdQueryRuleData(BaseQueryRuleData): +class ThresholdQueryRuleData(QueryRuleData): """Specific fields for query event types.""" @dataclass(frozen=True) @@ -256,57 +236,18 @@ class ThresholdQueryRuleData(BaseQueryRuleData): cardinality: Optional[ThresholdCardinality] type: Literal["threshold"] - language: Literal["kuery", "lucene"] threshold: ThresholdMapping @dataclass(frozen=True) -class EQLRuleData(BaseQueryRuleData): +class EQLRuleData(QueryRuleData): """EQL rules are a special case of query rules.""" type: Literal["eql"] - - @property - def parsed_query(self) -> kql.ast.Expression: - with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: - return eql.parse_query(self.query) - - @property - def unique_fields(self): - return list(set(str(f) for f in self.parsed_query if isinstance(f, eql.ast.Field))) - - def validate_query(self, beats_version: str, ecs_versions: List[str]): - """Validate an EQL query while checking TOMLRule.""" - # TODO: remove once py-eql supports ipv6 for cidrmatch - # Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4 - with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: - parsed = eql.parse_query(self.query) - - beat_types = [index.split("-")[0] for index in self.index or [] if "beat-*" in index] - beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None - - for version in ecs_versions: - schema = ecs.get_kql_schema(indexes=self.index or [], beat_schema=beat_schema, version=version) - - try: - # TODO: switch to custom cidrmatch that allows ipv6 - with ecs.KqlSchema2Eql(schema), eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: - eql.parse_query(self.query) - - except eql.EqlTypeMismatchError: - raise - - except eql.EqlParseError as exc: - message = exc.error_msg - trailer = None - 
if "Unknown field" in message and beat_types: - trailer = "\nTry adding event.module or event.dataset to specify beats module" - - raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) from None + language: Literal["eql"] # All of the possible rule types -AnyRuleData = Union[KQLRuleData, LuceneRuleData, MachineLearningRuleData, ThresholdQueryRuleData, EQLRuleData] +AnyRuleData = Union[QueryRuleData, EQLRuleData, MachineLearningRuleData, ThresholdQueryRuleData] @dataclass(frozen=True) @@ -365,17 +306,7 @@ class TOMLRuleContents(MarshmallowDataclassMixin): data: AnyRuleData = value["data"] metadata: RuleMeta = value["metadata"] - beats_version = metadata.beats_version or beats.get_max_version() - ecs_versions = metadata.ecs_versions or [ecs.get_max_version()] - - # call into these validate methods - if isinstance(data, (EQLRuleData, KQLRuleData)): - if metadata.query_schema_validation is False or metadata.maturity == "deprecated": - # Check the syntax only - _ = data.parsed_query - else: - # otherwise, do a full schema validation - data.validate_query(beats_version=beats_version, ecs_versions=ecs_versions) + return data.validate_query(metadata) def to_dict(self, strip_none_values=True) -> dict: dict_obj = super(TOMLRuleContents, self).to_dict(strip_none_values=strip_none_values) @@ -454,3 +385,7 @@ def downgrade_contents_from_rule(rule: TOMLRule, target_version: str) -> dict: payload["rule_id"] = str(uuid4()) payload = downgrade(payload, target_version) return payload + + +# avoid a circular import +from .rule_validators import KQLValidator, EQLValidator # noqa: E402 diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py new file mode 100644 index 000000000..825038e2a --- /dev/null +++ b/detection_rules/rule_validators.py @@ -0,0 +1,114 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. 
Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +"""Validation logic for rules containing queries.""" +from functools import cached_property +from typing import List + +import eql + +import kql +from detection_rules import beats, ecs +from detection_rules.rule import QueryValidator, QueryRuleData, RuleMeta + + +class KQLValidator(QueryValidator): + """Specific fields for query event types.""" + + @cached_property + def ast(self) -> kql.ast.Expression: + return kql.parse(self.query) + + @property + def unique_fields(self) -> List[str]: + return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field))) + + def to_eql(self) -> eql.ast.Expression: + return kql.to_eql(self.query) + + def validate(self, data: QueryRuleData, meta: RuleMeta) -> None: + """Static method to validate the query, called from the parent which contains [metadata] information.""" + ast = self.ast + + if meta.query_schema_validation is False or meta.maturity == "deprecated": + # syntax only, which is done via self.ast + return + + indexes = data.index or [] + beats_version = meta.beats_version or beats.get_max_version() + ecs_versions = meta.ecs_versions or [ecs.get_max_version()] + + beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index] + beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None + + if not ecs_versions: + kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema)) + else: + for version in ecs_versions: + schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema) + + try: + kql.parse(self.query, schema=schema) + except kql.KqlParseError as exc: + message = exc.error_msg + trailer = None + if "Unknown field" in message and beat_types: + trailer = "\nTry adding event.module or event.dataset to specify beats module" + + raise kql.KqlParseError(exc.error_msg, exc.line, 
exc.column, exc.source, +                                                 len(exc.caret.lstrip()), trailer=trailer) from None + + +class EQLValidator(QueryValidator): + +    @cached_property +    def ast(self) -> eql.ast.Expression: +        with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: +            return eql.parse_query(self.query) + +    @property +    def unique_fields(self) -> List[str]: +        return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field))) + +    def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: +        """Validate an EQL query while checking TOMLRule.""" +        _ = self.ast + +        if meta.query_schema_validation is False or meta.maturity == "deprecated": +            # syntax only, which is done via self.ast +            return + +        indexes = data.index or [] +        beats_version = meta.beats_version or beats.get_max_version() +        ecs_versions = meta.ecs_versions or [ecs.get_max_version()] + +        # TODO: remove once py-eql supports ipv6 for cidrmatch +        # Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4 +        with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: +            parsed = eql.parse_query(self.query) + +        beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index] +        beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None + +        for version in ecs_versions: +            schema = ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version) +            eql_schema = ecs.KqlSchema2Eql(schema) + +            try: +                # TODO: switch to custom cidrmatch that allows ipv6 +                with eql_schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: +                    eql.parse_query(self.query) + +            except eql.EqlTypeMismatchError: +                raise + +            except eql.EqlParseError as exc: +                message = exc.error_msg +                trailer = None +                if "Unknown field" in message and beat_types: +                    trailer = "\nTry adding event.module or event.dataset to specify beats module" + +                raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source, 
+ len(exc.caret.lstrip()), trailer=trailer) from None diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index a5c42e872..f92ad9a66 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -40,6 +40,7 @@ CodeString = NewType("CodeString", str) ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN)) Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN)) Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN)) +FilterLanguages = Literal["kuery", "lucene"] Markdown = NewType("MarkdownField", CodeString) Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated'] MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1)) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index 0ae9eb8cd..0faadf082 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -14,7 +14,7 @@ import eql import kql from detection_rules import attack, beats, ecs from detection_rules.packaging import load_versions -from detection_rules.rule import BaseQueryRuleData +from detection_rules.rule import QueryRuleData from detection_rules.rule_loader import FILE_PATTERN from detection_rules.utils import get_path, load_etc_dump from rta import get_ttp_names @@ -58,7 +58,7 @@ class TestValidRules(BaseRuleTest): ttp_names = get_ttp_names() for rule in self.production_rules: - if isinstance(rule.contents.data, BaseQueryRuleData) and rule.id in mappings: + if isinstance(rule.contents.data, QueryRuleData) and rule.id in mappings: matching_rta = mappings[rule.id].get('rta_name') self.assertIsNotNone(matching_rta, f'{self.rule_str(rule)} does not have RTAs') @@ -232,7 +232,7 @@ class TestRuleTags(BaseRuleTest): if 'Elastic' not in rule_tags: missing_required_tags.add('Elastic') - if isinstance(rule.contents.data, BaseQueryRuleData): + if isinstance(rule.contents.data, 
QueryRuleData): for index in rule.contents.data.index: expected_tags = required_tags_map.get(index, {}) expected_all = expected_tags.get('all', []) @@ -440,13 +440,13 @@ class TestTuleTiming(BaseRuleTest): for rule in self.all_rules: required = False - if isinstance(rule.contents.data, BaseQueryRuleData) and 'endgame-*' in rule.contents.data.index: + if isinstance(rule.contents.data, QueryRuleData) and 'endgame-*' in rule.contents.data.index: continue if rule.contents.data.type == 'query': required = True elif rule.contents.data.type == 'eql' and \ - eql.utils.get_query_type(rule.contents.data.parsed_query) != 'sequence': + eql.utils.get_query_type(rule.contents.data.ast) != 'sequence': required = True if required and rule.contents.data.timestamp_override != 'event.ingested': @@ -465,7 +465,7 @@ class TestTuleTiming(BaseRuleTest): for rule in self.all_rules: contents = rule.contents - if isinstance(contents.data, BaseQueryRuleData): + if isinstance(contents.data, QueryRuleData): if set(getattr(contents.data, "index", None) or []) & long_indexes and not contents.data.from_: missing.append(rule) diff --git a/tests/test_mappings.py b/tests/test_mappings.py index 48f749be7..5e441c539 100644 --- a/tests/test_mappings.py +++ b/tests/test_mappings.py @@ -7,7 +7,6 @@ import copy import warnings -from detection_rules.rule import KQLRuleData from . 
import get_data_files, get_fp_data_files from detection_rules.utils import combine_sources, evaluate, load_etc_dump from .base import BaseRuleTest @@ -30,7 +29,7 @@ class TestMappings(BaseRuleTest): mappings = load_etc_dump('rule-mapping.yml') for rule in self.production_rules: - if isinstance(rule.contents.data, KQLRuleData): + if rule.contents.data.type == "query" and rule.contents.data.language == "kuery": if rule.id not in mappings: continue @@ -63,7 +62,7 @@ class TestMappings(BaseRuleTest): def test_false_positives(self): """Test that expected results return against false positives.""" for rule in self.production_rules: - if isinstance(rule.contents.data, KQLRuleData): + if rule.contents.data.type == "query" and rule.contents.data.language == "kuery": for fp_name, merged_data in get_fp_data_files().items(): msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name) self.evaluate(copy.deepcopy(merged_data), rule, 0, msg)