# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import contextlib
import os
import re
from typing import Optional, Set

import eql
from lark import Token  # noqa: F401
from lark import Tree, Lark
from lark.exceptions import LarkError, UnexpectedEOF
from lark.visitors import Interpreter

from kql.errors import KqlParseError
from .ast import *  # noqa: F403

STRING_FIELDS = ("keyword", "text")


class KvTree(Tree):
    """Lark tree with helpers that split children into sub-trees and raw tokens."""

    @property
    def child_trees(self):
        """Return the children that are themselves ``KvTree`` nodes."""
        return [child for child in self.children if isinstance(child, KvTree)]

    @property
    def child_tokens(self):
        """Return the children that are lexer ``Token`` objects."""
        return [child for child in self.children if isinstance(child, Token)]


grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kql.g")

# Explicit encoding so loading the grammar does not depend on the platform locale.
with open(grammar_file, "rt", encoding="utf-8") as f:
    grammar = f.read()

lark_parser = Lark(grammar, propagate_positions=True, tree_class=KvTree, start=['query'], parser='lalr')


def wildcard2regex(wc: str) -> re.Pattern:
    """Compile a ``*``-wildcard pattern into an anchored regular expression.

    Each ``*`` matches any (possibly empty) run of characters (non-greedy);
    every other character is matched literally.
    """
    parts = wc.split("*")
    return re.compile("^{regex}$".format(regex=".*?".join(re.escape(w) for w in parts)))


def elasticsearch_type_family(mapping_type: str) -> str:
    """Get the family of type for an Elasticsearch mapping type.

    Unknown types are returned unchanged (they are their own family).
    """
    # https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html
    return {
        # range types
        "long_range": "range",
        "double_range": "range",
        "date_range": "range",
        "ip_range": "range",

        # text search types
        "annotated_text": "text",
        "completion": "text",
        "match_only_text": "text",
        "search_as_you_type": "text",
        # legacy misspelled keys, kept so existing schemas that used them still resolve
        "annotated-text": "text",
        "search-as_you_type": "text",

        # keyword
        "constant_keyword": "keyword",
        "wildcard": "keyword",

        # date
        "date_nanos": "date",

        # integer
        "token_count": "integer",
        "long": "integer",
        "short": "integer",
        "byte": "integer",
        "unsigned_long": "integer",

        # float
        "double": "float",
        "half_float": "float",
        "scaled_float": "float",

    }.get(mapping_type, mapping_type)


class BaseKqlParser(Interpreter):
    """Shared helpers for walking a KQL lark tree: errors, escaping, and schema-aware type conversion."""

    # any whitespace character other than a plain space
    NON_SPACE_WS = re.compile(r"[^\S ]+")
    # IPv4 address with an optional /0-/32 CIDR suffix
    ip_regex = re.compile("^" + eql.functions.CidrMatch.ip_re + "(/([0-2]?[0-9]|3[0-2]))?$")

    # escape sequences recognized inside unquoted literals
    unquoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n"}

    for special in "\\():<>\"*{}]":
        unquoted_escapes["\\" + special] = special

    unquoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(unquoted_escapes)) + ")")

    # escape sequences recognized inside double-quoted strings
    quoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n", "\\\\": "\\", "\\\"": "\""}
    quoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(quoted_escapes)) + ")")

    def __init__(self, text, schema=None):
        """Store the query text (split into lines for error messages) and the optional field schema.

        :param text: the raw KQL query string
        :param schema: optional mapping of dotted field name -> type; names containing ``*`` act as patterns
        """
        self.text = text
        self.lines = [t.rstrip("\r\n") for t in self.text.splitlines(True)]
        self.scoped_field = None
        self.mapping_schema = schema
        self.star_fields = []

        if schema:
            for field in schema:
                if "*" in field:
                    self.star_fields.append(wildcard2regex(field))

    def assert_lower_token(self, *tokens):
        """Raise a parse error for any keyword token (e.g. ``and``/``or``/``not``) that is not lowercase."""
        for token in tokens:
            if str(token) != str(token).lower():
                raise self.error(token, "Expected '{lower}' but got '{token}'".format(token=token,
                                                                                     lower=str(token).lower()))

    def error(self, node, message, end=False, cls=KqlParseError, width=None, **kwargs):
        """Generate an error exception but don't raise it.

        :param node: token or tree carrying ``line``/``column``/``end_line``/``end_column`` (1-based)
        :param message: error text; formatted with ``kwargs`` when any are given
        :param end: when True, do not try to underline the node's full width
        :param cls: exception class to instantiate
        :param width: explicit underline width; derived from the node when omitted
        :return: an instance of ``cls`` for the caller to raise
        """
        if kwargs:
            message = message.format(**kwargs)

        line_number = node.line - 1
        column = node.column - 1

        # get more lines for more informative error messages:
        # the error line plus up to two lines before it, and up to three lines after it
        before = self.lines[:line_number + 1][-3:]
        after = self.lines[line_number + 1:][:3]

        source = '\n'.join(before)
        trailer = '\n'.join(after)

        # Determine if the error message can easily look like this
        #                                                     ^^^^
        if width is None and not end and node.line == node.end_line:
            if not self.NON_SPACE_WS.search(self.lines[line_number][column:node.end_column]):
                width = node.end_column - node.column

        if width is None:
            width = 1

        return cls(message, line_number, column, source, width=width, trailer=trailer)

    def __default__(self, tree):
        raise NotImplementedError("Unable to visit tree {} of type: {}".format(tree, tree.data))

    def unescape_literal(self, token):  # type: (Token) -> (int|float|str|bool)
        """Convert a literal token into a Python value (quoted strings vs. unquoted literals)."""
        if token.type == "QUOTED_STRING":
            return self.convert_quoted_string(token.value)
        else:
            return self.convert_unquoted_literal(token.value)

    @contextlib.contextmanager
    def scope(self, field):
        """Make ``field`` the current field for value nodes within the ``with`` block.

        The previous scoped field is restored on exit, even if an exception is raised.
        """
        # with self.scope(field) as field:
        #     ...
        previous = self.scoped_field
        self.scoped_field = field
        try:
            yield field
        finally:
            self.scoped_field = previous

    def get_field_type(self, dotted_path, lark_tree=None):
        """Return the schema type of ``dotted_path``, or None if there is no schema or no entry.

        When ``lark_tree`` is given and a schema is set, an unknown field (not present and not matching
        any ``*`` pattern from the schema) raises a parse error located at that node.
        """
        matches_pattern = any(regex.match(dotted_path) for regex in self.star_fields)

        if self.mapping_schema is not None:
            if lark_tree is not None and dotted_path not in self.mapping_schema and not matches_pattern:
                raise self.error(lark_tree, "Unknown field")

            return self.mapping_schema.get(dotted_path)

    def get_field_types(self, wildcard_dotted_path, lark_tree=None) -> Optional[Set[str]]:
        """Return the set of schema types for a field path that may contain ``*`` wildcards.

        Returns None when there is no schema or the type is unknown. When ``lark_tree`` is given, a path
        that matches nothing in the schema raises a parse error located at that node; without a node,
        validation is skipped (consistent with ``get_field_type``).
        """
        if "*" not in wildcard_dotted_path:
            field_type = self.get_field_type(wildcard_dotted_path, lark_tree=lark_tree)
            return {field_type} if field_type is not None else None

        if self.mapping_schema is None:
            return None

        regex = wildcard2regex(wildcard_dotted_path)
        field_types = {field_type for field, field_type in self.mapping_schema.items()
                       if regex.fullmatch(field) is not None}

        if not field_types:
            if lark_tree is None:
                # no node to attach a location to; self.error(None, ...) would crash with AttributeError
                return None
            raise self.error(lark_tree, "Unknown field")

        return field_types

    @staticmethod
    def get_literal_type(literal_value):
        """Map a Python literal to the Elasticsearch type name used for comparison with field types."""
        # bool is checked before int because bool is a subclass of int
        if isinstance(literal_value, bool):
            return "boolean"
        elif isinstance(literal_value, float):
            return "float"
        elif isinstance(literal_value, int):
            return "long"
        elif eql.utils.is_string(literal_value):
            # this will be converted when compared to the field
            return "keyword"
        elif literal_value is None:
            return "null"
        else:
            raise NotImplementedError("Unknown literal type: {}".format(type(literal_value).__name__))

    def convert_value(self, field_name, python_value, value_tree):
        """Coerce a parsed literal to the type of ``field_name`` in the schema.

        :param field_name: dotted field path (may contain ``*``)
        :param python_value: value produced by ``unescape_literal``
        :param value_tree: node used to locate any error
        :return: the (possibly converted) value
        :raises KqlParseError: if the field has several types or the value cannot match its type
        """
        field_type = None
        field_types = self.get_field_types(field_name)
        value_type = self.get_literal_type(python_value)

        if field_types is not None:
            if len(field_types) == 1:
                field_type = next(iter(field_types))
            elif len(field_types) > 1:
                # sorted so the message is deterministic regardless of set ordering
                raise self.error(value_tree, f"{field_name} has multiple types {', '.join(sorted(field_types))}")

        if field_type is not None and field_type != value_type:
            field_type_family = elasticsearch_type_family(field_type)

            if field_type_family in STRING_FIELDS:
                return eql.utils.to_unicode(python_value)
            elif field_type_family in ("float", "integer"):
                try:
                    return float(python_value) if field_type_family == "float" else int(python_value)
                except (TypeError, ValueError):
                    # e.g. a non-numeric string or a null literal; reported as a type mismatch below
                    pass
            elif field_type_family == "ip" and value_type == "keyword":
                if "::" in python_value or self.ip_regex.match(python_value) is not None:
                    return python_value
            elif field_type_family == 'date' and value_type in STRING_FIELDS:
                # this will not validate datemath syntax
                return python_value

            raise self.error(value_tree, "Value doesn't match {field}'s type: {type}",
                             field=field_name, type=field_type)

        # otherwise, there's nothing to convert
        return python_value

    @classmethod
    def convert_unquoted_literal(cls, text):
        """Convert an unquoted literal to bool, None, int, float, or an unescaped string (in that order)."""
        if text == "true":
            return True
        elif text == "false":
            return False
        elif text == "null":
            return None
        else:
            for numeric in (int, float):
                try:
                    return numeric(text)
                except ValueError:
                    pass

        text = cls.unquoted_regex.sub(lambda r: cls.unquoted_escapes[r.group()], text)
        return text

    @classmethod
    def convert_quoted_string(cls, text):
        """Strip the surrounding quotes from ``text`` and resolve its escape sequences."""
        inner_text = text[1:-1]
        unescaped = cls.quoted_regex.sub(lambda r: cls.quoted_escapes[r.group()], inner_text)
        return unescaped


class KqlParser(BaseKqlParser):
    """Interpreter that turns a KQL lark tree into the AST node classes imported from ``.ast``."""

    def or_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        terms = [self.visit(t) for t in tree.child_trees]
        return OrExpr(terms)

    def and_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        terms = [self.visit(t) for t in tree.child_trees]
        return AndExpr(terms)

    def not_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return NotExpr(self.visit(tree.children[-1]))

    @contextlib.contextmanager
    def nest(self, lark_tree):
        """Temporarily swap the schema for the sub-schema of a nested field.

        NOTE(review): currently unused, since nested_query raises before calling it.
        """
        schema = self.mapping_schema
        dotted_path = self.visit(lark_tree)

        if self.get_field_type(dotted_path, lark_tree) != "nested":
            raise self.error(lark_tree, "Expected a nested field")

        try:
            self.mapping_schema = self.mapping_schema[dotted_path]
            yield
        finally:
            self.mapping_schema = schema

    def nested_query(self, tree):
        # field_tree, query_tree = tree.child_trees
        #
        # with self.nest(field_tree) as field:
        #     return NestedQuery(field, self.visit(query_tree))

        raise self.error(tree, "Nested queries are not yet supported")

    def field_value_expression(self, tree):
        field_tree, expr = tree.child_trees

        with self.scope(self.visit(field_tree)) as field:
            # check the field against the schema
            self.get_field_types(field.name, field_tree)
            return FieldComparison(field, self.visit(expr))

    def field_range_expression(self, tree):
        field_tree, operator, literal = tree.children
        with self.scope(self.visit(field_tree)) as field:
            value = self.convert_value(field.name, self.visit(literal), literal)
            return FieldRange(field, operator, Value.from_python(value))

    def or_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return OrValues([self.visit(t) for t in tree.child_trees])

    def and_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return AndValues([self.visit(t) for t in tree.child_trees])

    def not_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return NotValue(self.visit(tree.children[-1]))

    def literal(self, tree):
        return self.unescape_literal(tree.children[0])

    def field(self, tree):
        literal = self.visit(tree.children[0])
        return Field(eql.utils.to_unicode(literal))

    def value(self, tree):
        """Build a value node for the current scoped field: Exists, Wildcard, or a type-converted Value."""
        if self.scoped_field is None:
            raise self.error(tree, "Value not tied to field")

        field_name = self.scoped_field.name
        token = tree.children[0]
        value = self.unescape_literal(token)

        if token.type == "UNQUOTED_LITERAL" and "*" in token.value:
            field_type = self.get_field_type(field_name)
            # a value made only of '*' characters means "field exists"
            if len(value.replace("*", "")) == 0:
                return Exists()

            if field_type is not None and field_type not in ("keyword", "wildcard"):
                raise self.error(tree, "Unable to perform wildcard on field {field} of {type}",
                                 field=field_name, type=field_type)

            return Wildcard(token.value)

        # try to convert the value to the appropriate type
        # example: 1 -> "1" if the field is actually keyword
        value = self.convert_value(field_name, value, tree)
        return Value.from_python(value)


def lark_parse(text):
    """Parse KQL ``text`` into a lark tree, translating lark errors into ``KqlParseError``."""
    if not text.strip():
        raise KqlParseError("No query provided", 0, 0, "")

    # only used here for its line splitting, to build error context
    walker = BaseKqlParser(text)

    try:
        return lark_parser.parse(text)
    except UnexpectedEOF:
        raise KqlParseError("Unexpected EOF", len(walker.lines), len(walker.lines[-1].strip()), walker.lines[-1])
    except LarkError as exc:
        # show the offending line and the one before it; clamp the start so an error on the
        # first line does not turn into a negative (wrapping) slice index
        start = max(exc.line - 2, 0)
        raise KqlParseError("Invalid syntax", exc.line - 1, exc.column - 1,
                            '\n'.join(walker.lines[start:exc.line]))