[Bug] Ignore Other Keep Wildcards (#5792)

* Ignore other Keep Wildcards

* Added a unit test for multiple keeps

* Add keep star unit tests
This commit is contained in:
Eric Forte
2026-03-09 19:33:27 -04:00
committed by GitHub
parent 926befff83
commit 26d37dd62e
3 changed files with 94 additions and 16 deletions
+15 -14
View File
@@ -991,8 +991,8 @@ class ESQLRuleData(QueryRuleData):
# Enforce KEEP command for ESQL rules and that METADATA fields are present in non-aggregate queries # Enforce KEEP command for ESQL rules and that METADATA fields are present in non-aggregate queries
# Match | followed by optional whitespace/newlines and then 'keep' # Match | followed by optional whitespace/newlines and then 'keep'
keep_pattern = re.compile(r"\|\s*keep\b\s+([^\|]+)", re.IGNORECASE | re.DOTALL) keep_pattern = re.compile(r"\|\s*keep\b\s+([^\|]+)", re.IGNORECASE | re.DOTALL)
keep_match = keep_pattern.search(query_lower) keep_matches = list(keep_pattern.finditer(query_lower))
if not keep_match: if not keep_matches:
raise EsqlSemanticError( raise EsqlSemanticError(
f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query." f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query."
) )
@@ -1000,16 +1000,17 @@ class ESQLRuleData(QueryRuleData):
# Ensure that keep clause includes metadata fields on non-aggregate queries # Ensure that keep clause includes metadata fields on non-aggregate queries
aggregate_pattern = re.compile(r"\|\s*stats\b(?:\s+([^\|]+?))?(?:\s+by\s+([^\|]+))?", re.IGNORECASE | re.DOTALL) aggregate_pattern = re.compile(r"\|\s*stats\b(?:\s+([^\|]+?))?(?:\s+by\s+([^\|]+))?", re.IGNORECASE | re.DOTALL)
if not aggregate_pattern.search(query_lower): if not aggregate_pattern.search(query_lower):
raw_keep = re.sub(r"//.*", "", keep_match.group(1)) for keep_match in keep_matches:
keep_fields = [field.strip() for field in raw_keep.split(",") if field.strip()] raw_keep = re.sub(r"//.*", "", keep_match.group(1))
if "*" not in keep_fields: keep_fields = [field.strip() for field in raw_keep.split(",") if field.strip()]
required_metadata = {"_id", "_version", "_index"} if "*" not in keep_fields:
if not required_metadata.issubset(set(map(str.strip, keep_fields))): required_metadata = {"_id", "_version", "_index"}
raise EsqlSemanticError( if not required_metadata.issubset(set(map(str.strip, keep_fields))):
f"Rule: {data['name']} contains a keep clause without" raise EsqlSemanticError(
f" metadata fields '_id', '_version', and '_index' ->" f"Rule: {data['name']} contains a keep clause without"
f" Add '_id', '_version', '_index' to the keep command." f" metadata fields '_id', '_version', and '_index' ->"
) f" Add '_id', '_version', '_index' to the keep command."
)
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -1261,7 +1262,7 @@ class BaseRuleContents(ABC):
return obj return obj
def _uses_keep_star(self, hashable_dict: dict[str, Any]) -> bool: def _uses_keep_star(self, hashable_dict: dict[str, Any]) -> bool:
"""Check if this is an ES|QL rule that uses `| keep *`.""" """Check if this is an ES|QL rule that uses `| keep *` or fields ending with '*'."""
if hashable_dict.get("language") != "esql": if hashable_dict.get("language") != "esql":
return False return False
@@ -1273,7 +1274,7 @@ class BaseRuleContents(ABC):
keep_match: re.Match[str] | None = keep_pattern.search(query) keep_match: re.Match[str] | None = keep_pattern.search(query)
if keep_match: if keep_match:
keep_fields: list[str] = [field.strip() for field in keep_match.group(1).split(",")] keep_fields: list[str] = [field.strip() for field in keep_match.group(1).split(",")]
return "*" in keep_fields return any(field == "*" or field.endswith("*") for field in keep_fields)
return False return False
@abstractmethod @abstractmethod
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "detection_rules" name = "detection_rules"
version = "1.5.51" version = "1.5.52"
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Securitys Detection Engine." description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Securitys Detection Engine."
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
+78 -1
View File
@@ -8,7 +8,13 @@ from copy import deepcopy
import pytest import pytest
from detection_rules.esql_errors import EsqlSchemaError, EsqlSyntaxError, EsqlTypeMismatchError, EsqlUnknownIndexError from detection_rules.esql_errors import (
EsqlSchemaError,
EsqlSemanticError,
EsqlSyntaxError,
EsqlTypeMismatchError,
EsqlUnknownIndexError,
)
from detection_rules.misc import ( from detection_rules.misc import (
get_default_config, get_default_config,
getdefault, getdefault,
@@ -26,6 +32,59 @@ from .base import BaseRuleTest
class TestRemoteRules(BaseRuleTest): class TestRemoteRules(BaseRuleTest):
"""Test rules against a remote Elastic stack instance.""" """Test rules against a remote Elastic stack instance."""
def test_get_hashable_content_required_fields_popped_when_keep_star_used(self):
"""Hashable content must not contain required_fields when query uses keep * or field wildcards."""
file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
original_production_rule = load_rule_contents(file_path)
production_rule = deepcopy(original_production_rule)[0]
# Non-aggregate queries must include _id, _version, _index in keep when keep is not exactly "*"
base = "from logs-aws.cloudtrail* metadata _id, _version, _index\n"
base += '| where event.action == "start"\n | eval Esql.entity_type = cloud.target.entity.type\n | keep '
keep_star_queries = [
base + "*",
base + "Esql.*, _id, _version, _index",
base + "host.name, Esql.*, _id, _version, _index",
base + "event.*, _id, _version, _index",
]
for query in keep_star_queries:
production_rule_copy = deepcopy(production_rule)
production_rule_copy["rule"]["query"] = query
rule = RuleCollection().load_dict(production_rule_copy)
hashable = rule.contents.get_hashable_content()
assert "required_fields" not in hashable, f"required_fields should be popped for keep-star query: {query!r}"
def test_get_hashable_content_required_fields_kept_when_no_keep_star(self):
"""Hashable content keeps required_fields when query uses explicit keep (no wildcards)."""
file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
original_production_rule = load_rule_contents(file_path)
production_rule = deepcopy(original_production_rule)[0]
production_rule["rule"]["query"] = """
from logs-aws.cloudtrail* metadata _id, _version, _index
| where event.action == "start"
| keep _id, _version, _index
"""
rule = RuleCollection().load_dict(production_rule)
api = rule.contents.to_api_format()
hashable = rule.contents.get_hashable_content()
if "required_fields" in api:
assert "required_fields" in hashable, "required_fields must not be popped when keep has no wildcards"
def test_get_hashable_content_required_fields_kept_for_explicit_keep_only(self):
"""Hashable content keeps required_fields when keep lists only explicit fields."""
file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
original_production_rule = load_rule_contents(file_path)
production_rule = deepcopy(original_production_rule)[0]
production_rule["rule"]["query"] = """
from logs-aws.cloudtrail* metadata _id, _version, _index
| where event.action == "start"
| keep host.name, user.name, _id, _version, _index
"""
rule = RuleCollection().load_dict(production_rule)
api = rule.contents.to_api_format()
hashable = rule.contents.get_hashable_content()
if "required_fields" in api:
assert "required_fields" in hashable
def test_esql_related_integrations(self): def test_esql_related_integrations(self):
"""Test an ESQL rule has its related integrations built correctly.""" """Test an ESQL rule has its related integrations built correctly."""
file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"]) file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
@@ -251,3 +310,21 @@ class TestRemoteRules(BaseRuleTest):
event.outcome, _id, _version, _index event.outcome, _id, _version, _index
""" """
_ = RuleCollection().load_dict(production_rule) _ = RuleCollection().load_dict(production_rule)
def test_esql_multiple_keeps(self):
"""Test an ESQL rule that has multiple keeps in the query."""
file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
original_production_rule = load_rule_contents(file_path)
production_rule = deepcopy(original_production_rule)[0]
production_rule["metadata"]["integration"] = ["aws"]
production_rule["rule"]["query"] = """
from logs-aws.cloudtrail* metadata _id, _version, _index
| where @timestamp > now() - 30 minutes
and event.dataset in ("aws.cloudtrail", "aws.billing")
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep aws.cloudtrail.user_identity.type, _id, _version, _index
| eval Esql.user_type = aws.cloudtrail.user_identity.type
| keep Esql.user_type
"""
with pytest.raises(EsqlSemanticError):
_ = RuleCollection().load_dict(production_rule)