diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 6b4ee994e..5e3b0fd5f 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -1405,7 +1405,7 @@ def get_unique_query_fields(rule: TOMLRule) -> List[str]: cfg = set_eql_config(rule.contents.metadata.get('min_stack_version')) with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: - parsed = kql.parse(query) if language == 'kuery' else eql.parse_query(query) + parsed = kql.parse(query, normalize_kql_keywords=True) if language == 'kuery' else eql.parse_query(query) return sorted(set(str(f) for f in parsed if isinstance(f, (eql.ast.Field, kql.ast.Field)))) diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index 6631db83e..86d200863 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -36,7 +36,7 @@ class KQLValidator(QueryValidator): @cached_property def ast(self) -> kql.ast.Expression: - return kql.parse(self.query) + return kql.parse(self.query, normalize_kql_keywords=True) @cached_property def unique_fields(self) -> List[str]: @@ -80,7 +80,7 @@ class KQLValidator(QueryValidator): beats_version, ecs_version) try: - kql.parse(self.query, schema=schema) + kql.parse(self.query, schema=schema, normalize_kql_keywords=True) except kql.KqlParseError as exc: message = exc.error_msg trailer = err_trailer @@ -135,7 +135,7 @@ class KQLValidator(QueryValidator): # Validate the query against the schema try: - kql.parse(self.query, schema=integration_schema) + kql.parse(self.query, schema=integration_schema, normalize_kql_keywords=True) except kql.KqlParseError as exc: if exc.error_msg == "Unknown field": field = extract_error_field(self.query, exc) diff --git a/detection_rules/utils.py b/detection_rules/utils.py index 6bc7e527f..d09f16675 100644 --- a/detection_rules/utils.py +++ b/detection_rules/utils.py @@ -241,7 +241,7 @@ def convert_time_span(span: str) -> int: def evaluate(rule, events): """Evaluate a query against events.""" - evaluator = kql.get_evaluator(kql.parse(rule.query)) + evaluator = kql.get_evaluator(kql.parse(rule.query, normalize_kql_keywords=True)) filtered = list(filter(evaluator, events)) return filtered diff --git a/lib/kql/kql/__init__.py b/lib/kql/kql/__init__.py index 2c6b0ef23..66c91ca9b 100644 --- a/lib/kql/kql/__init__.py +++ b/lib/kql/kql/__init__.py @@ -45,12 +45,12 @@ def to_eql(text, optimize=True, schema=None): return converted.optimize(recursive=True) if optimize else converted -def parse(text, optimize=True, schema=None): +def parse(text, optimize: bool = True, schema: dict = None, normalize_kql_keywords: bool = False): if isinstance(text, bytes): text = text.decode("utf-8") lark_parsed = lark_parse(text) - converted = KqlParser(text, schema=schema).visit(lark_parsed) + converted = KqlParser(text, schema=schema, normalize_kql_keywords=normalize_kql_keywords).visit(lark_parsed) return converted.optimize(recursive=True) if optimize else converted diff --git a/lib/kql/kql/parser.py b/lib/kql/kql/parser.py index e3017f2fc..c36403be4 100644 --- a/lib/kql/kql/parser.py +++ b/lib/kql/kql/parser.py @@ -104,22 +104,29 @@ class BaseKqlParser(Interpreter): quoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n", "\\\\": "\\", "\\\"": "\""} quoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(quoted_escapes)) + ")") - def __init__(self, text, schema=None): + def __init__(self, text: str, schema: dict = None, normalize_kql_keywords: bool = True) -> None: + """Initialize the parser. Defaults to normalizing KQL keywords to lowercase.""" self.text = text self.lines = [t.rstrip("\r\n") for t in self.text.splitlines(True)] self.scoped_field = None self.mapping_schema = schema self.star_fields = [] + self.normalize_kql_keywords = normalize_kql_keywords if schema: for field, field_type in schema.items(): if "*" in field: self.star_fields.append(wildcard2regex(field)) - def assert_lower_token(self, *tokens): + def assert_lower_token(self, *tokens: Token) -> None: + """Assert that the token is lowercase and converts token if not.""" for token in tokens: - if str(token) != str(token).lower(): - raise self.error(token, "Expected '{lower}' but got '{token}'".format(token=token, lower=str(token).lower())) + lower_token = str(token).lower() + if str(token) != lower_token: + if self.normalize_kql_keywords: + token.value = lower_token + else: + raise self.error(token, f"Expected '{lower_token}' but got '{token}'") def error(self, node, message, end=False, cls=KqlParseError, width=None, **kwargs): """Generate an error exception but dont raise it.""" diff --git a/tests/kuery/test_lint.py b/tests/kuery/test_lint.py index 7f0e97bd1..31953cbd7 100644 --- a/tests/kuery/test_lint.py +++ b/tests/kuery/test_lint.py @@ -34,6 +34,12 @@ class LintTests(unittest.TestCase): with self.assertRaises(kql.KqlParseError): kql.parse(q) + for q in queries: + # Test query successfully converts and parses + parsed_query = kql.parse(q, normalize_kql_keywords=True) + # Test that the parsed query is not equal to the original query, that the transformation was applied + self.assertNotEqual(str(parsed_query), q, f"Parsed query {parsed_query} matches the original {q}") + def test_lint_precedence(self): self.validate("a:b or (c:d and e:f)", "a:b or c:d and e:f") self.validate("(a:b and (c:d or e:f))", "a:b and (c:d or e:f)") diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index c66563c34..39fe344cb 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -67,7 +67,7 @@ class TestValidRules(BaseRuleTest): ) ): source = rule.contents.data.query - tree = kql.parse(source, optimize=False) + tree = kql.parse(source, optimize=False, normalize_kql_keywords=True) optimized = tree.optimize(recursive=True) err_message = f'\n{self.rule_str(rule)} Query not optimized for rule\n' \ f'Expected: {optimized}\nActual: {source}'