class ExtendedTypeHint(Enum):
    """Extra type hints layered on top of the ``TypeHint`` values imported from ``eql.parser``.

    Only ``IP`` is defined; it lets in-set validation tell an ``ip`` field apart from a plain string.
    """

    IP = "ip"

    @classmethod
    def primitives(cls) -> Tuple[Any, ...]:
        """Get all primitive types: the four ``TypeHint`` primitives plus ``ExtendedTypeHint.IP``."""
        return TypeHint.Boolean, TypeHint.Numeric, TypeHint.Null, TypeHint.String, ExtendedTypeHint.IP

    def is_primitive(self) -> bool:
        """Check if a type is a primitive."""
        return self in self.primitives()


def custom_in_set(self, node: KvTree) -> NodeInfo:
    """Override and address the limitations of the eql in_set method.

    Runs the generic "every member has the same type as the outer expression" check and then,
    when Elasticsearch syntax is enabled and the schema is an ``ecs.KqlSchema2Eql`` whose
    ``kql_schema`` maps the outer field to ``"ip"``, validates the members a second time with
    the outer node temporarily typed as ``ExtendedTypeHint.IP``.

    Args:
        self: the ``LarkToEQL`` instance this function is installed on as ``in_set``.
        node: parse tree of the in-set expression; ``node.child_trees`` is visited to obtain
            the outer expression and the list of set members.

    Returns:
        A ``NodeInfo`` wrapping ``ast.InSet`` typed as ``TypeHint.Boolean``; it is nullable when
        the outer expression or any member is nullable.

    Raises:
        The exception built by ``self._type_error`` when the outer expression is not a primitive
        or when a member fails ``validate_type`` against the outer expression.
    """
    outer, container = self.visit(node.child_trees)  # type: (NodeInfo, list[NodeInfo])

    if not outer.validate_type(ExtendedTypeHint.primitives()):
        # can't compare non-primitives to sets
        raise self._type_error(outer, ExtendedTypeHint.primitives())

    # Check that everything inside the container has the same type as outside
    error_message = "Unable to compare {expected_type} to {actual_type}"
    for inner in container:
        if not inner.validate_type(outer):
            raise self._type_error(inner, outer, error_message)

    if self._elasticsearch_syntax and hasattr(outer, "type_info"):
        # Check edge case of in_set and ip/string comparison
        outer_type = outer.type_info
        if isinstance(self._schema, ecs.KqlSchema2Eql):
            type_hint = self._schema.kql_schema.get(str(outer.node), "unknown")
            if hasattr(self._schema, "type_mapping") and type_hint == "ip":
                outer.type_info = ExtendedTypeHint.IP
                try:
                    for inner in container:
                        if not inner.validate_type(outer):
                            # The exception is built while ``outer`` is still typed as IP;
                            # NOTE(review): assumes ``_type_error`` formats its message eagerly - confirm.
                            raise self._type_error(inner, outer, error_message)
                finally:
                    # reset the type, also when validation fails, so the node is never left re-typed
                    outer.type_info = outer_type

    # This will always evaluate to true/false, so it should be a boolean
    term = ast.InSet(outer.node, [c.node for c in container])
    nullable = outer.nullable or any(c.nullable for c in container)
    return NodeInfo(term, TypeHint.Boolean, nullable=nullable, source=node)


def custom_base_parse_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a parse function so that ``LarkToEQL.in_set`` is ``custom_in_set`` while it runs.

    The replacement is made on the ``LarkToEQL`` class attribute, so it is visible to every user
    of that class for the duration of the wrapped call; the previous attribute is put back
    afterwards.

    Args:
        func: the parse callable to wrap; called as ``func(query, start=start, **kwargs)``.

    Returns:
        The wrapping callable, carrying ``func``'s metadata via ``functools.wraps``.
    """

    @wraps(func)
    def wrapper(query: str, start: Optional[str] = None, **kwargs: Any) -> Any:
        original_in_set = LarkToEQL.in_set
        LarkToEQL.in_set = custom_in_set
        try:
            return func(query, start=start, **kwargs)
        finally:  # Using finally to ensure that the original method is restored
            LarkToEQL.in_set = original_in_set

    return wrapper


# Install the wrapper over eql's private parse entry point at import time.
eql.parser._parse = custom_base_parse_decorator(base_parse)
class TestEQLInSet(BaseRuleTest):
    """Checks for the overridden EQL ``in`` set validation."""

    def test_eql_in_set(self):
        """An ip field compared with string literals is rejected; the address field loads."""
        metadata = {
            "creation_date": "2020/12/15",
            "integration": ["endpoint", "windows"],
            "maturity": "production",
            "min_stack_comments": "New fields added: required_fields, related_integrations, setup",
            "min_stack_version": "8.3.0",
            "updated_date": "2024/03/26",
        }
        rule_body = {
            "author": ["Elastic"],
            "description": """
            Test Rule.
            """,
            "false_positives": ["Fake."],
            "from": "now-9m",
            "index": ["winlogbeat-*", "logs-endpoint.events.*", "logs-windows.sysmon_operational-*"],
            "language": "eql",
            "license": "Elastic License v2",
            "name": "Fake Test Rule",
            "references": [
                "https://example.com",
            ],
            "risk_score": 47,
            "rule_id": "4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
            "severity": "medium",
            "tags": [
                "Domain: Endpoint",
                "OS: Windows",
                "Use Case: Threat Detection",
                "Tactic: Execution",
                "Data Source: Elastic Defend",
                "Data Source: Sysmon",
            ],
            "type": "eql",
            "query": """
            sequence by host.id, process.entity_id with maxspan = 5s
            [network where destination.ip in ("127.0.0.1", "::1")]
            """,
        }
        rule_contents = {"metadata": metadata, "rule": rule_body}
        collection = RuleCollection()

        # First pass: destination.ip against string literals has to be refused.
        mismatch_pattern = r"Error in both stack and integrations checks:.*Unable to compare ip to string.*"
        with self.assertRaisesRegex(ValueError, mismatch_pattern):
            collection.load_dict(rule_contents)

        # Second pass: the same comparison on destination.address is expected to load cleanly.
        rule_contents["rule"][
            "query"
        ] = """
            sequence by host.id, process.entity_id with maxspan = 10s
            [network where destination.address in ("192.168.1.1", "::1")]
            """
        collection.load_dict(rule_contents)
optional_minor_and_patch=True) class TestEndpointQuery(BaseRuleTest): """Test endpoint-specific rules.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0", + ) def test_os_and_platform_in_query(self): """Test that all endpoint rules have an os defined and linux includes platform.""" for rule in self.production_rules: - if not rule.contents.data.get('language') in ('eql', 'kuery'): + if not rule.contents.data.get("language") in ("eql", "kuery"): continue - if rule.path.parent.name not in ('windows', 'macos', 'linux'): + if rule.path.parent.name not in ("windows", "macos", "linux"): # skip cross-platform for now continue ast = rule.contents.data.ast fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))] - err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule' - if 'host.os.type' not in fields: + err_msg = f"{self.rule_str(rule)} missing required field for endpoint rule" + if "host.os.type" not in fields: # Exception for Forwarded Events which contain Windows-only fields. 
- if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields): - self.assertIn('host.os.type', fields, err_msg) + if rule.path.parent.name == "windows" and not any(field.startswith("winlog.") for field in fields): + self.assertIn("host.os.type", fields, err_msg) # going to bypass this for now # if rule.path.parent.name == 'linux': @@ -58,8 +63,9 @@ class TestEndpointQuery(BaseRuleTest): class TestNewTerms(BaseRuleTest): """Test new term rules.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." + ) def test_history_window_start(self): """Test new terms history window start field.""" @@ -67,39 +73,49 @@ class TestNewTerms(BaseRuleTest): if rule.contents.data.type == "new_terms": # validate history window start field exists and is correct - assert rule.contents.data.new_terms.history_window_start, \ - "new terms field found with no history_window_start field defined" - assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \ - f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" + assert ( + rule.contents.data.new_terms.history_window_start + ), "new terms field found with no history_window_start field defined" + assert ( + rule.contents.data.new_terms.history_window_start[0].field == "history_window_start" + ), f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." 
+ ) def test_new_terms_field_exists(self): # validate new terms and history window start fields are correct for rule in self.production_rules: if rule.contents.data.type == "new_terms": - assert rule.contents.data.new_terms.field == "new_terms_fields", \ - f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" + assert ( + rule.contents.data.new_terms.field == "new_terms_fields" + ), f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." + ) def test_new_terms_fields(self): """Test new terms fields are schema validated.""" # ecs validation for rule in self.production_rules: if rule.contents.data.type == "new_terms": meta = rule.contents.metadata - feature_min_stack = Version.parse('8.4.0') + feature_min_stack = Version.parse("8.4.0") current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - min_stack_version = Version.parse(meta.get("min_stack_version")) if \ - meta.get("min_stack_version") else None - min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ - current_package_version else min_stack_version + min_stack_version = ( + Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None + ) + min_stack_version = ( + current_package_version + if min_stack_version is None or min_stack_version < current_package_version + else min_stack_version + ) - assert min_stack_version >= feature_min_stack, \ - f"New Terms rule types only compatible with {feature_min_stack}+" - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] - beats_version = get_stack_schemas()[str(min_stack_version)]['beats'] + assert ( + 
min_stack_version >= feature_min_stack + ), f"New Terms rule types only compatible with {feature_min_stack}+" + ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] + beats_version = get_stack_schemas()[str(min_stack_version)]["beats"] # checks if new terms field(s) are in ecs, beats non-ecs or integration schemas queryvalidator = QueryValidator(rule.contents.data.query) @@ -113,43 +129,53 @@ class TestNewTerms(BaseRuleTest): package=tag, integration="", rule_stack_version=min_stack_version, - packages_manifest=integration_manifests) + packages_manifest=integration_manifests, + ) if latest_tag_compat_ver: integration_schema = integration_schemas[tag][latest_tag_compat_ver] for policy_template in integration_schema.keys(): schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template]) for new_terms_field in rule.contents.data.new_terms.value: - assert new_terms_field in schema.keys(), \ - f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" + assert ( + new_terms_field in schema.keys() + ), f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." 
+ ) def test_new_terms_max_limit(self): """Test new terms max limit.""" # validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862 for rule in self.production_rules: if rule.contents.data.type == "new_terms": meta = rule.contents.metadata - feature_min_stack = Version.parse('8.4.0') - feature_min_stack_extended_fields = Version.parse('8.6.0') + feature_min_stack = Version.parse("8.4.0") + feature_min_stack_extended_fields = Version.parse("8.6.0") current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - min_stack_version = Version.parse(meta.get("min_stack_version")) if \ - meta.get("min_stack_version") else None - min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ - current_package_version else min_stack_version + min_stack_version = ( + Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None + ) + min_stack_version = ( + current_package_version + if min_stack_version is None or min_stack_version < current_package_version + else min_stack_version + ) if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields: - assert len(rule.contents.data.new_terms.value) == 1, \ - f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" + assert ( + len(rule.contents.data.new_terms.value) == 1 + ), f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.6.0"), "Test only applicable to 8.4+ stacks for new terms feature." 
+ ) def test_new_terms_fields_unique(self): """Test new terms fields are unique.""" # validate fields are unique for rule in self.production_rules: if rule.contents.data.type == "new_terms": - assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \ - f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" + assert len(set(rule.contents.data.new_terms.value)) == len( + rule.contents.data.new_terms.value + ), f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" class TestESQLRules(BaseRuleTest):