diff --git a/detection_rules/beats.py b/detection_rules/beats.py
index 23f3c1a45..57f5c3b3d 100644
--- a/detection_rules/beats.py
+++ b/detection_rules/beats.py
@@ -6,11 +6,12 @@
 import os
 
 import kql
+import eql
 import requests
 import yaml
 
 from .semver import Version
-from .utils import unzip, load_etc_dump, save_etc_dump, get_etc_path
+from .utils import unzip, load_etc_dump, save_etc_dump, get_etc_path, cached
 
 
 def download_latest_beats_schema():
@@ -129,34 +130,16 @@ def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str):
     return {field["name"]: field for field in sorted(flattened, key=lambda f: f["name"])}
 
 
-SCHEMA = None
-
-
+@cached
 def read_beats_schema():
-    global SCHEMA
+    beats_schemas = os.listdir(get_etc_path("beats_schemas"))
+    latest = max(beats_schemas, key=lambda b: Version(b.lstrip("v")))
 
-    if SCHEMA is None:
-        beats_schemas = os.listdir(get_etc_path("beats_schemas"))
-        latest = max(beats_schemas, key=lambda b: Version(b.lstrip("v")))
-
-        SCHEMA = load_etc_dump("beats_schemas", latest)
-
-    return SCHEMA
+    return load_etc_dump("beats_schemas", latest)
 
 
-def get_schema_for_query(tree: kql.ast, beats: list) -> dict:
+def get_schema_from_datasets(beats, modules, datasets):
     filtered = {}
-    modules = set()
-    datasets = set()
-
-    # extract out event.module and event.dataset from the query's AST
-    for node in tree:
-        if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"):
-            modules.update(child.value for child in node.value if isinstance(child, kql.ast.String))
-
-        if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"):
-            datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String))
-
     beats_schema = read_beats_schema()
 
     # infer the module if only a dataset are defined
@@ -173,3 +156,39 @@ def get_schema_for_query(tree: kql.ast, beats: list) -> dict:
             filtered.update(get_beats_sub_schema(beats_schema, beat, module, *datasets))
 
     return filtered
+
+
+def get_schema_from_eql(tree: eql.ast.BaseNode, beats: list) -> dict:
+    modules = set()
+    datasets = set()
+
+    # extract out event.module and event.dataset from the query's AST
+    for node in tree:
+        if isinstance(node, eql.ast.Comparison) and node.comparator == node.EQ and \
+                isinstance(node.right, eql.ast.String):
+            if node.left == eql.ast.Field("event", ["module"]):
+                modules.add(node.right.value)  # raw string; render() would keep the quotes
+            elif node.left == eql.ast.Field("event", ["dataset"]):
+                datasets.add(node.right.value)  # raw string; render() would keep the quotes
+        elif isinstance(node, eql.ast.InSet):
+            if node.expression == eql.ast.Field("event", ["module"]):
+                modules.update(node.get_literals())  # merge each literal, not the collection
+            elif node.expression == eql.ast.Field("event", ["dataset"]):
+                datasets.update(node.get_literals())  # merge each literal, not the collection
+
+    return get_schema_from_datasets(beats, modules, datasets)
+
+
+def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list) -> dict:
+    modules = set()
+    datasets = set()
+
+    # extract out event.module and event.dataset from the query's AST
+    for node in tree:
+        if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"):
+            modules.update(child.value for child in node.value if isinstance(child, kql.ast.String))
+
+        if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"):
+            datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String))
+
+    return get_schema_from_datasets(beats, modules, datasets)
diff --git a/detection_rules/ecs.py b/detection_rules/ecs.py
index 2bb8c7d16..f3eaaefff 100644
--- a/detection_rules/ecs.py
+++ b/detection_rules/ecs.py
@@ -10,6 +10,8 @@ import shutil
 import json
 
 import requests
+import eql
+import eql.types
 import yaml
 
 from .semver import Version
@@ -164,6 +166,34 @@ def flatten_multi_fields(schema):
     return converted
 
 
+class KqlSchema2Eql(eql.Schema):
+    type_mapping = {
+        "keyword": eql.types.TypeHint.String,
+        "ip": eql.types.TypeHint.String,
+        "float": eql.types.TypeHint.Numeric,
+        "double": eql.types.TypeHint.Numeric,
+        "long": eql.types.TypeHint.Numeric,
+        "short": eql.types.TypeHint.Numeric,
+    }
+
+    def __init__(self, kql_schema):
+        self.kql_schema = kql_schema
+        eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False)
+
+    def validate_event_type(self, event_type):
+        # allow all event types to fill in X:
+        # `X` where ....
+        return True
+
+    def get_event_type_hint(self, event_type, path):
+        dotted = ".".join(path)
+        elasticsearch_type = self.kql_schema.get(dotted)
+        eql_hint = self.type_mapping.get(elasticsearch_type)
+
+        if eql_hint is not None:
+            return eql_hint, None
+
+
 @cached
 def get_kql_schema(version=None, indexes=None, beat_schema=None):
     """Get schema for KQL."""
diff --git a/detection_rules/misc.py b/detection_rules/misc.py
index 768df0121..d3a7f3af9 100644
--- a/detection_rules/misc.py
+++ b/detection_rules/misc.py
@@ -75,6 +75,9 @@ def schema_prompt(name, value=None, required=False, **options):
     if name == 'rule_id':
         default = str(uuid.uuid4())
 
+    if len(enum) == 1 and required and field_type != "array":
+        return enum[0]
+
     def _check_type(_val):
         if field_type in ('number', 'integer') and not str(_val).isdigit():
             print('Number expected but got: {}'.format(_val))
diff --git a/detection_rules/rule.py b/detection_rules/rule.py
index 43f53acad..7ff68ff03 100644
--- a/detection_rules/rule.py
+++ b/detection_rules/rule.py
@@ -10,6 +10,7 @@ import os
 
 import click
 import kql
+import eql
 
 from . import ecs, beats
 from .attack import TACTICS, build_threat_map_entry, technique_lookup
@@ -70,9 +71,12 @@ class Rule(object):
         return self.contents.get('query')
 
     @property
-    def parsed_kql(self):
-        if self.query and self.contents['language'] == 'kuery':
-            return kql.parse(self.query)
+    def parsed_query(self):
+        if self.query:
+            if self.contents['language'] == 'kuery':
+                return kql.parse(self.query)
+            elif self.contents['language'] == 'eql':
+                return eql.parse_query(self.query)
 
     @property
     def filters(self):
@@ -152,10 +156,50 @@ class Rule(object):
 
         schema_cls.validate(contents, role=self.type)
 
-        if query and self.query and self.contents['language'] == 'kuery':
+        if query and self.query is not None:
             ecs_versions = self.metadata.get('ecs_version')
             indexes = self.contents.get("index", [])
-            self._validate_kql(ecs_versions, indexes, self.query, self.name)
+
+            if self.contents['language'] == 'kuery':
+                self._validate_kql(ecs_versions, indexes, self.query, self.name)
+
+            if self.contents['language'] == 'eql':
+                self._validate_eql(ecs_versions, indexes, self.query, self.name)
+
+    @staticmethod
+    @cached
+    def _validate_eql(ecs_versions, indexes, query, name):
+        # validate against all specified schemas or the latest if none specified
+        parsed = eql.parse_query(query)
+        beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
+        beat_schema = beats.get_schema_from_eql(parsed, beat_types) if beat_types else None
+
+        ecs_versions = ecs_versions or [None]  # [None] selects get_kql_schema's default version
+        schemas = []
+
+        for version in ecs_versions:
+            try:
+                schemas.append(ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version))
+            except KeyError:
+                raise KeyError('Unknown ecs schema version: {} in rule {}.\n'
+                               'Do you need to update schemas?'.format(version, name)) from None
+
+        for schema in schemas:
+            try:
+                with ecs.KqlSchema2Eql(schema):
+                    eql.parse_query(query)
+
+            except eql.EqlTypeMismatchError:
+                raise
+
+            except eql.EqlParseError as exc:
+                message = exc.error_msg
+                trailer = None
+                if "Unknown field" in message and beat_types:
+                    trailer = "\nTry adding event.module and event.dataset to specify beats module"
+
+                raise type(exc)(exc.error_msg, exc.line, exc.column, exc.source,
+                                len(exc.caret.lstrip()), trailer=trailer) from None
 
     @staticmethod
     @cached
@@ -163,7 +207,7 @@ class Rule(object):
         # validate against all specified schemas or the latest if none specified
         parsed = kql.parse(query)
         beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
-        beat_schema = beats.get_schema_for_query(parsed, beat_types) if beat_types else None
+        beat_schema = beats.get_schema_from_kql(parsed, beat_types) if beat_types else None
 
         if not ecs_versions:
             kql.parse(query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))
diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py
index d52d012ce..e309f37bb 100644
--- a/detection_rules/rule_loader.py
+++ b/detection_rules/rule_loader.py
@@ -93,12 +93,13 @@ def load_rules(file_lookup=None, verbose=True, error=True):
                     raise KeyError("Rule has duplicate name to {}".format(
                         next(r for r in rules if r.name == rule.name).path))
 
-                if rule.parsed_kql:
-                    if rule.parsed_kql in queries:
+                parsed_query = rule.parsed_query
+                if parsed_query is not None:
+                    if parsed_query in queries:
                         raise KeyError("Rule has duplicate query with {}".format(
-                            next(r for r in rules if r.parsed_kql == rule.parsed_kql).path))
+                            next(r for r in rules if r.parsed_query == parsed_query).path))
 
-                    queries.append(rule.parsed_kql)
+                    queries.append(parsed_query)
 
                 if not re.match(FILE_PATTERN, os.path.basename(rule.path)):
                     raise ValueError(f"Rule {rule.path} does not meet rule name standard of {FILE_PATTERN}")
diff --git a/detection_rules/schemas/__init__.py b/detection_rules/schemas/__init__.py
index 0d74dfa50..d29352d6a 100644
--- a/detection_rules/schemas/__init__.py
+++ b/detection_rules/schemas/__init__.py
@@ -7,8 +7,9 @@ from .rta_schema import validate_rta_mapping
 from ..semver import Version
 
 # import all of the schema versions
-from .v78 import ApiSchema78
-from .v79 import ApiSchema79
+from .v7_8 import ApiSchema78
+from .v7_9 import ApiSchema79
+from .v7_10 import ApiSchema710
 
 __all__ = (
     "all_schemas",
@@ -21,9 +22,10 @@ __all__ = (
 all_schemas = [
     ApiSchema78,
     ApiSchema79,
+    ApiSchema710,
 ]
 
-CurrentSchema = max(all_schemas, key=lambda cls: Version(cls.STACK_VERSION))
+CurrentSchema = all_schemas[-1]
 
 
 def downgrade(api_contents: dict, target_version: str):
diff --git a/detection_rules/schemas/v7_10.py b/detection_rules/schemas/v7_10.py
new file mode 100644
index 000000000..c2bc2c137
--- /dev/null
+++ b/detection_rules/schemas/v7_10.py
@@ -0,0 +1,36 @@
+# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+# or more contributor license agreements. Licensed under the Elastic License;
+# you may not use this file except in compliance with the Elastic License.
+
+"""Definitions for rule metadata and schemas."""
+
+import jsl
+from .v7_9 import ApiSchema79
+
+
+# rule types
+EQL = "eql"
+
+
+class ApiSchema710(ApiSchema79):
+    """Schema for siem rule in API format."""
+
+    STACK_VERSION = "7.10"
+    RULE_TYPES = ApiSchema79.RULE_TYPES + [EQL]
+
+    type = jsl.StringField(enum=RULE_TYPES, required=True)
+
+    # there might be a bug in jsl that requires us to redefine these here
+    query_scope = ApiSchema79.query_scope
+    saved_id_scope = ApiSchema79.saved_id_scope
+    ml_scope = ApiSchema79.ml_scope
+    threshold_scope = ApiSchema79.threshold_scope
+
+    with jsl.Scope(EQL) as eql_scope:
+        eql_scope.index = jsl.ArrayField(jsl.StringField(), required=False)
+        eql_scope.query = jsl.StringField(required=True)
+        eql_scope.language = jsl.StringField(enum=[EQL], required=True, default=EQL)
+        eql_scope.type = jsl.StringField(enum=[EQL], required=True)
+
+    with jsl.Scope(jsl.DEFAULT_ROLE) as default_scope:
+        default_scope.type = type
diff --git a/detection_rules/schemas/v78.py b/detection_rules/schemas/v7_8.py
similarity index 100%
rename from detection_rules/schemas/v78.py
rename to detection_rules/schemas/v7_8.py
diff --git a/detection_rules/schemas/v79.py b/detection_rules/schemas/v7_9.py
similarity index 99%
rename from detection_rules/schemas/v79.py
rename to detection_rules/schemas/v7_9.py
index 0dc8c97fe..87a2b2d05 100644
--- a/detection_rules/schemas/v79.py
+++ b/detection_rules/schemas/v7_9.py
@@ -5,7 +5,7 @@
 """Definitions for rule metadata and schemas."""
 
 import jsl
-from .v78 import ApiSchema78
+from .v7_8 import ApiSchema78
 
 
 OPERATORS = ['equals']
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index 225d86a9e..08966862c 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -5,6 +5,7 @@
 """Test stack versioned schemas."""
 import unittest
 import uuid
+import eql
 
 from detection_rules.rule import Rule
 from detection_rules.schemas import downgrade, CurrentSchema
@@ -106,3 +107,36 @@ class TestSchemas(unittest.TestCase):
 
         with self.assertRaisesRegex(ValueError, "Unsupported rule type"):
             downgrade(api_contents, "7.8")
+
+    def test_eql_validation(self):
+        base_fields = {
+            "author": ["Elastic"],
+            "description": "test description",
+            "index": ["filebeat-*"],
+            "language": "eql",
+            "license": "Elastic License",
+            "name": "test rule",
+            "risk_score": 21,
+            "rule_id": str(uuid.uuid4()),
+            "severity": "low",
+            "type": "eql"
+        }
+
+        Rule("test.toml", dict(base_fields, query="""
+            process where process.name == "cmd.exe"
+        """))
+
+        with self.assertRaises(eql.EqlSyntaxError):
+            Rule("test.toml", dict(base_fields, query="""
+                process where process.name == this!is$not#v@lid
+            """))
+
+        with self.assertRaises(eql.EqlSemanticError):
+            Rule("test.toml", dict(base_fields, query="""
+                process where process.invalid_field == "hello world"
+            """))
+
+        with self.assertRaises(eql.EqlTypeMismatchError):
+            Rule("test.toml", dict(base_fields, query="""
+                process where process.pid == "some string field"
+            """))