Patch summary (free text before the first "diff --git" header; patch tools skip it).

What the patch changes, as shown by the hunks below:
  * detection_rules/rule.py (ESQLRuleData): the KEEP validation collects every "| keep"
    clause with finditer() instead of only the first search() hit. For queries without a
    "| stats" command, each keep clause must list _id, _version and _index unless the
    clause contains a bare "*".
  * detection_rules/rule.py (BaseRuleContents._uses_keep_star): now also returns True when
    any kept field ends with "*" (e.g. "Esql.*"), not only for a bare "*".
  * pyproject.toml: version bump 1.5.51 -> 1.5.52.
  * tests/test_rules_remote.py: three get_hashable_content() tests for required_fields and
    one test expecting EsqlSemanticError when a later keep clause drops the metadata fields.

NOTE(review): the line structure of this patch was reconstructed from a whitespace-collapsed
copy. Line counts were checked against every hunk header, but indentation (especially inside
the triple-quoted query strings) is inferred -- confirm against the original commit.
NOTE(review): _uses_keep_star still calls keep_pattern.search(query), so it inspects only the
first keep clause, while the validator now inspects all of them -- confirm this is intended.
NOTE(review): in _uses_keep_star, field == "*" is already covered by field.endswith("*").
NOTE(review): the two "required_fields kept" tests only assert when "required_fields" is in
the API output, so they pass vacuously if the fixture produces none -- verify the fixture.

diff --git a/detection_rules/rule.py b/detection_rules/rule.py
index c7624478e..091e1f11a 100644
--- a/detection_rules/rule.py
+++ b/detection_rules/rule.py
@@ -991,8 +991,8 @@ class ESQLRuleData(QueryRuleData):
         # Enforce KEEP command for ESQL rules and that METADATA fields are present in non-aggregate queries
         # Match | followed by optional whitespace/newlines and then 'keep'
         keep_pattern = re.compile(r"\|\s*keep\b\s+([^\|]+)", re.IGNORECASE | re.DOTALL)
-        keep_match = keep_pattern.search(query_lower)
-        if not keep_match:
+        keep_matches = list(keep_pattern.finditer(query_lower))
+        if not keep_matches:
             raise EsqlSemanticError(
                 f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query."
             )
@@ -1000,16 +1000,17 @@ class ESQLRuleData(QueryRuleData):
         # Ensure that keep clause includes metadata fields on non-aggregate queries
         aggregate_pattern = re.compile(r"\|\s*stats\b(?:\s+([^\|]+?))?(?:\s+by\s+([^\|]+))?", re.IGNORECASE | re.DOTALL)
         if not aggregate_pattern.search(query_lower):
-            raw_keep = re.sub(r"//.*", "", keep_match.group(1))
-            keep_fields = [field.strip() for field in raw_keep.split(",") if field.strip()]
-            if "*" not in keep_fields:
-                required_metadata = {"_id", "_version", "_index"}
-                if not required_metadata.issubset(set(map(str.strip, keep_fields))):
-                    raise EsqlSemanticError(
-                        f"Rule: {data['name']} contains a keep clause without"
-                        f" metadata fields '_id', '_version', and '_index' ->"
-                        f" Add '_id', '_version', '_index' to the keep command."
-                    )
+            for keep_match in keep_matches:
+                raw_keep = re.sub(r"//.*", "", keep_match.group(1))
+                keep_fields = [field.strip() for field in raw_keep.split(",") if field.strip()]
+                if "*" not in keep_fields:
+                    required_metadata = {"_id", "_version", "_index"}
+                    if not required_metadata.issubset(set(map(str.strip, keep_fields))):
+                        raise EsqlSemanticError(
+                            f"Rule: {data['name']} contains a keep clause without"
+                            f" metadata fields '_id', '_version', and '_index' ->"
+                            f" Add '_id', '_version', '_index' to the keep command."
+                        )
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -1261,7 +1262,7 @@ class BaseRuleContents(ABC):
         return obj
 
     def _uses_keep_star(self, hashable_dict: dict[str, Any]) -> bool:
-        """Check if this is an ES|QL rule that uses `| keep *`."""
+        """Check if this is an ES|QL rule that uses `| keep *` or fields ending with '*'."""
         if hashable_dict.get("language") != "esql":
             return False
 
@@ -1273,7 +1274,7 @@ class BaseRuleContents(ABC):
         keep_match: re.Match[str] | None = keep_pattern.search(query)
         if keep_match:
             keep_fields: list[str] = [field.strip() for field in keep_match.group(1).split(",")]
-            return "*" in keep_fields
+            return any(field == "*" or field.endswith("*") for field in keep_fields)
         return False
 
     @abstractmethod
diff --git a/pyproject.toml b/pyproject.toml
index 08a599c01..0f75bae58 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "detection_rules"
-version = "1.5.51"
+version = "1.5.52"
 description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
 readme = "README.md"
 requires-python = ">=3.12"
diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py
index 749275a8a..3e4cd2fa4 100644
--- a/tests/test_rules_remote.py
+++ b/tests/test_rules_remote.py
@@ -8,7 +8,13 @@ from copy import deepcopy
 
 import pytest
 
-from detection_rules.esql_errors import EsqlSchemaError, EsqlSyntaxError, EsqlTypeMismatchError, EsqlUnknownIndexError
+from detection_rules.esql_errors import (
+    EsqlSchemaError,
+    EsqlSemanticError,
+    EsqlSyntaxError,
+    EsqlTypeMismatchError,
+    EsqlUnknownIndexError,
+)
 from detection_rules.misc import (
     get_default_config,
     getdefault,
@@ -26,6 +32,59 @@ from .base import BaseRuleTest
 class TestRemoteRules(BaseRuleTest):
     """Test rules against a remote Elastic stack instance."""
 
+    def test_get_hashable_content_required_fields_popped_when_keep_star_used(self):
+        """Hashable content must not contain required_fields when query uses keep * or field wildcards."""
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        production_rule = deepcopy(original_production_rule)[0]
+        # Non-aggregate queries must include _id, _version, _index in keep when keep is not exactly "*"
+        base = "from logs-aws.cloudtrail* metadata _id, _version, _index\n"
+        base += '| where event.action == "start"\n | eval Esql.entity_type = cloud.target.entity.type\n | keep '
+        keep_star_queries = [
+            base + "*",
+            base + "Esql.*, _id, _version, _index",
+            base + "host.name, Esql.*, _id, _version, _index",
+            base + "event.*, _id, _version, _index",
+        ]
+        for query in keep_star_queries:
+            production_rule_copy = deepcopy(production_rule)
+            production_rule_copy["rule"]["query"] = query
+            rule = RuleCollection().load_dict(production_rule_copy)
+            hashable = rule.contents.get_hashable_content()
+            assert "required_fields" not in hashable, f"required_fields should be popped for keep-star query: {query!r}"
+
+    def test_get_hashable_content_required_fields_kept_when_no_keep_star(self):
+        """Hashable content keeps required_fields when query uses explicit keep (no wildcards)."""
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        production_rule = deepcopy(original_production_rule)[0]
+        production_rule["rule"]["query"] = """
+        from logs-aws.cloudtrail* metadata _id, _version, _index
+        | where event.action == "start"
+        | keep _id, _version, _index
+        """
+        rule = RuleCollection().load_dict(production_rule)
+        api = rule.contents.to_api_format()
+        hashable = rule.contents.get_hashable_content()
+        if "required_fields" in api:
+            assert "required_fields" in hashable, "required_fields must not be popped when keep has no wildcards"
+
+    def test_get_hashable_content_required_fields_kept_for_explicit_keep_only(self):
+        """Hashable content keeps required_fields when keep lists only explicit fields."""
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        production_rule = deepcopy(original_production_rule)[0]
+        production_rule["rule"]["query"] = """
+        from logs-aws.cloudtrail* metadata _id, _version, _index
+        | where event.action == "start"
+        | keep host.name, user.name, _id, _version, _index
+        """
+        rule = RuleCollection().load_dict(production_rule)
+        api = rule.contents.to_api_format()
+        hashable = rule.contents.get_hashable_content()
+        if "required_fields" in api:
+            assert "required_fields" in hashable
+
     def test_esql_related_integrations(self):
         """Test an ESQL rule has its related integrations built correctly."""
         file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
@@ -251,3 +310,21 @@ class TestRemoteRules(BaseRuleTest):
                 event.outcome, _id, _version, _index
         """
         _ = RuleCollection().load_dict(production_rule)
+
+    def test_esql_multiple_keeps(self):
+        """Test an ESQL rule that has multiple keeps in the query."""
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        production_rule = deepcopy(original_production_rule)[0]
+        production_rule["metadata"]["integration"] = ["aws"]
+        production_rule["rule"]["query"] = """
+        from logs-aws.cloudtrail* metadata _id, _version, _index
+        | where @timestamp > now() - 30 minutes
+            and event.dataset in ("aws.cloudtrail", "aws.billing")
+            and aws.cloudtrail.user_identity.type == "IAMUser"
+        | keep aws.cloudtrail.user_identity.type, _id, _version, _index
+        | eval Esql.user_type = aws.cloudtrail.user_identity.type
+        | keep Esql.user_type
+        """
+        with pytest.raises(EsqlSemanticError):
+            _ = RuleCollection().load_dict(production_rule)