From 41809f1dc5f15ac92657241c0716619a87aa37da Mon Sep 17 00:00:00 2001 From: Ross Wolf <31489089+rw-access@users.noreply.github.com> Date: Mon, 29 Jun 2020 23:05:14 -0600 Subject: [PATCH] Add KQL module --- kql/__init__.py | 60 ++++++++++ kql/ast.py | 244 ++++++++++++++++++++++++++++++++++++++++ kql/eql2kql.py | 121 ++++++++++++++++++++ kql/errors.py | 17 +++ kql/evaluator.py | 151 +++++++++++++++++++++++++ kql/kql.g | 51 +++++++++ kql/kql2eql.py | 100 +++++++++++++++++ kql/optimizer.py | 129 +++++++++++++++++++++ kql/parser.py | 285 +++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 1158 insertions(+) create mode 100644 kql/__init__.py create mode 100644 kql/ast.py create mode 100755 kql/eql2kql.py create mode 100644 kql/errors.py create mode 100644 kql/evaluator.py create mode 100644 kql/kql.g create mode 100755 kql/kql2eql.py create mode 100644 kql/optimizer.py create mode 100644 kql/parser.py diff --git a/kql/__init__.py b/kql/__init__.py new file mode 100644 index 000000000..b3af0e336 --- /dev/null +++ b/kql/__init__.py @@ -0,0 +1,60 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import eql + +from . import ast +from .eql2kql import Eql2Kql +from .errors import KqlParseError, KqlCompileError +from .evaluator import FilterGenerator +from .kql2eql import KqlToEQL +from .parser import lark_parse, KqlParser + +__version__ = '0.1.4' +__all__ = ( + "ast", + "to_eql", + "lint", + "parse", + "from_eql", + "get_evaluator", + "KqlParseError", + "KqlCompileError", +) + + +def to_eql(text, optimize=True, schema=None): + lark_parsed = lark_parse(text) + + converted = KqlToEQL(text, schema=schema).visit(lark_parsed) + return converted.optimize(recursive=True) if optimize else converted + + +def parse(text, optimize=True, schema=None): + lark_parsed = lark_parse(text) + converted = KqlParser(text, schema=schema).visit(lark_parsed) + + return converted.optimize(recursive=True) if optimize else converted + + +def lint(text): + return parse(text, optimize=True).render() + + +def from_eql(tree, optimize=True): + if not isinstance(tree, eql.ast.EqlNode): + try: + tree = eql.parse_query(tree, implied_any=True) + except eql.EqlSemanticError: + tree = eql.parse_expression(tree) + + converted = Eql2Kql().walk(tree) + return converted.optimize(recursive=True) if optimize else converted + + +def get_evaluator(tree, optimize=False): + if not isinstance(tree, ast.KqlNode): + tree = parse(tree, optimize=optimize) + + return FilterGenerator().filter(tree) diff --git a/kql/ast.py b/kql/ast.py new file mode 100644 index 000000000..33d35ddce --- /dev/null +++ b/kql/ast.py @@ -0,0 +1,244 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import re +from string import Template + +from eql.ast import BaseNode +from eql.errors import EqlCompileError +from eql.utils import is_number, is_string + +__all__ = ( + "KqlNode", + "Value", + "Null", + "Number", + "Boolean", + "List", + "Expression", + "String", + "Wildcard", + "NotValue", + "OrValues", + "AndValues", + "AndExpr", + "OrExpr", + "NotExpr", + "FieldComparison", + "Field", + "FieldRange", + "NestedQuery", + "Exists", +) + + +class KqlNode(BaseNode): + def optimize(self, recursive=True): + from .optimizer import Optimizer + return Optimizer().walk(self) + + def _render(self): + return BaseNode.render(self) + + def render(self, precedence=None, **kwargs): + """Render an EQL node and add parentheses to support orders of operation.""" + rendered = self._render(**kwargs) + if precedence is not None and self.precedence is not None and self.precedence > precedence: + return '({})'.format(rendered) + return rendered + + +class Value(KqlNode): + __slots__ = "value", + precedence = 1 + + def __init__(self, value): + self.value = value + + @classmethod + def from_python(cls, value): + if value is None: + return Null() + elif isinstance(value, bool): + return Boolean(value) + elif is_number(value): + return Number(value) + elif is_string(value): + return String(value) + else: + raise EqlCompileError("Unknown type {} for value {}".format(type(value).__name__, value)) + + +class Null(Value): + def __init__(self, value=None): + Value.__init__(self, None) + + def _render(self): + return "null" + + +class Number(Value): + def _render(self): + return str(self.value) + + +class Boolean(Value): + def _render(self): + return 'true' if self.value else 'false' + + +class String(Value): + unescapable = re.compile(r'^[^\\():<>"*{} \t\r\n]+$') + escapes = {"\t": "\\t", "\r": "\\r", "\"": "\\\""} + + def _render(self): + # pass through as-is since nothing needs to be escaped + if self.unescapable.match(self.value) is not None: + return str(self.value) + + regex = r"[{}]".format("".join(re.escape(s) for s in sorted(self.escapes))) + return '"{}"'.format(re.sub(regex, lambda r: self.escapes[r.group()], self.value)) + + +class Wildcard(Value): + escapes = {"\t": "\\t", "\r": "\\r"} + slash_escaped = r'''^\\():<>"*{} ''' + + def _render(self): + escaped = [] + for char in self.value: + if char in self.slash_escaped: + escaped.append("\\") + escaped.append(char) + elif char in self.escapes: + escaped.append(self.escapes[char]) + else: + escaped.append(char) + return ''.join(escaped) + + +class List(KqlNode): + __slots__ = "items", + precedence = Value.precedence + 1 + operator = "" + template = Template("$items") + + def __init__(self, items): + self.items = items + KqlNode.__init__(self) + + @property + def delims(self): + return {"items": " {} ".format(self.operator)} + + def __eq__(self, other): + from .optimizer import Optimizer + from functools import cmp_to_key + if type(self) == type(other): + a = list(self.items) + b = list(other.items) + a.sort(key=cmp_to_key(Optimizer.sort_key)) + b.sort(key=cmp_to_key(Optimizer.sort_key)) + return a == b + + return False + + +class NotValue(KqlNode): + __slots__ = "value", + template = Template("not $value") + precedence = Value.precedence + 1 + + def __init__(self, value): + self.value = value + KqlNode.__init__(self) + + +class AndValues(List): + precedence = List.precedence + 1 + operator = "and" + + +class OrValues(List): + precedence = AndValues.precedence + 1 + operator = "or" + + +class Field(KqlNode): + __slots__ = "name", + precedence = Value.precedence + template = Template("$name") + + def __init__(self, name): + self.name = name + KqlNode.__init__(self) + + @property + def path(self): + return self.name.split(".") + + @classmethod + def from_path(cls, path): + dotted = ".".join(path) + return cls(dotted) + + +class Expression(KqlNode): + """Intermediate node for class hierarchy.""" + + +class FieldRange(Expression, KqlNode): + __slots__ = "field", "operator", "value", + precedence = Field.precedence + template = Template("$field $operator $value") + + def __init__(self, field, operator, value): + self.field = field + self.operator = operator + self.value = value + + +class NestedQuery(Expression): + __slots__ = "field", "expr", + precedence = Field.precedence + 1 + template = Template("$field:{$expr}") + + def __init__(self, field, expr): + self.field = field + self.expr = expr + + +class FieldComparison(Expression): + __slots__ = "field", "value", + precedence = FieldRange.precedence + template = Template("$field:$value") + + def __init__(self, field, value): + self.field = field + self.value = value + + +class Exists(KqlNode): + __slots__ = tuple() + precedence = FieldComparison.precedence + template = Template("*") + + +class NotExpr(Expression): + __slots__ = "expr", + precedence = FieldComparison.precedence + 1 + template = Template("not $expr") + + def __init__(self, expr): + self.expr = expr + + +class AndExpr(Expression, List): + precedence = NotExpr.precedence + 1 + operator = "and" + + +class OrExpr(Expression, List): + precedence = AndExpr.precedence + 1 + operator = "or" diff --git a/kql/eql2kql.py b/kql/eql2kql.py new file mode 100755 index 000000000..9d139fc6f --- /dev/null +++ b/kql/eql2kql.py @@ -0,0 +1,121 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +#!/usr/bin/env python +import eql +from eql import DepthFirstWalker + +from .ast import ( + Value, String, OrValues, Field, Expression, FieldRange, FieldComparison, + NotExpr, AndExpr, OrExpr, Exists +) + + +class Eql2Kql(DepthFirstWalker): + + def _walk_default(self, tree, *args, **kwargs): + if isinstance(tree, eql.ast.EqlNode): + raise eql.errors.EqlCompileError("Unable to convert {}".format(tree)) + else: + return tree + + def check_field_expression(self, tree): + if not isinstance(tree, Expression): + raise eql.errors.EqlCompileError("Expected expression, but got {}".format(repr(tree))) + return tree + + def check_field_expressions(self, trees): + for tree in trees: + self.check_field_expression(tree) + return trees + + def _walk_and(self, tree): # type: (eql.ast.And) -> AndExpr + return AndExpr(self.check_field_expressions(tree.terms)) + + def _walk_or(self, tree): # type: (eql.ast.Or) -> OrExpr + return OrExpr(self.check_field_expressions(tree.terms)) + + def _walk_not(self, tree): # type: (eql.ast.Not) -> NotExpr + return NotExpr(self.check_field_expression(tree.term)) + + def _walk_is_null(self, node): # type: (eql.ast.IsNull) -> FieldComparison + if not isinstance(node.expr, Field): + raise eql.errors.EqlCompileError("Unable to compare a non-field [{}] to null".format(node.expr)) + + return NotExpr(FieldComparison(node.expr, Exists())) + + def _walk_is_not_null(self, node): # type: (eql.ast.IsNotNull) -> Expression + if not isinstance(node.expr, Field): + raise eql.errors.EqlCompileError("Unable to compare a non-field [{}] to null".format(node.expr)) + + return FieldComparison(node.expr, Exists()) + + def _walk_field(self, tree): # type: (eql.ast.Field) -> Field + if any(eql.utils.is_number(n) for n in tree.path): + raise eql.errors.EqlCompileError("Unable to convert array field: {}".format(tree)) + + return Field(tree.render()) + + def _walk_in_set(self, tree): # type: (eql.ast.InSet) -> FieldComparison + if not isinstance(tree.expression, Field) or not all(isinstance(v, Value) for v in tree.container): + raise eql.errors.EqlCompileError("Unable to convert `{}`".format(tree.expression, tree)) + + return FieldComparison(tree.expression, OrValues(tree.container)) + + def _walk_function_call(self, tree): # type: (eql.ast.FunctionCall) -> KqlNode + if tree.name in ("wildcard", "cidrMatch"): + if isinstance(tree.arguments[0], Field): + return FieldComparison(tree.arguments[0], OrValues(tree.arguments[1:])) + + raise eql.errors.EqlCompileError("Unable to convert `{}`".format(tree)) + + def _walk_literal(self, tree): + return Value.from_python(tree.value) + + def _walk_event_query(self, tree): # type: (eql.ast.EventQuery) -> KqlNode + if tree.event_type == eql.schema.EVENT_TYPE_ANY: + return self.check_field_expression(tree.query) + + event_check = FieldComparison(Field("event.category"), String(tree.event_type)) + + # for `x where true` shorthand, drop the `where true` + if tree.query == Value.from_python(True): + return event_check + + self.check_field_expression(tree.query) + return AndExpr([event_check, tree.query]) + + def _walk_filter_pipe(self, tree): # type: (eql.pipes.FilterPipe) -> KqlNode + return self.check_field_expression(tree.expression) + + def _walk_piped_query(self, tree): # type: (eql.ast.PipedQuery) -> KqlNode + if not tree.pipes: + return tree.first + + return AndExpr([tree.first] + tree.pipes) + + LT, LE, EQ, NE, GE, GT = ('<', '<=', '==', '!=', '>=', '>') + flipped = {LT: GE, LE: GT, + EQ: EQ, NE: NE, + GE: LT, GT: LE} + + def _walk_comparison(self, tree): # type: (eql.ast.Comparison) -> KqlNode + left = tree.left + op = tree.comparator + right = tree.right + + # move the literal to the right + if isinstance(left, eql.ast.Literal): + left, right = right, left + op = self.flipped[op] + + if isinstance(left, Field) and isinstance(right, Value): + if op == eql.ast.Comparison.EQ: + return FieldComparison(left, right) + elif op == eql.ast.Comparison.NE: + return NotExpr(FieldComparison(left, right)) + else: + return FieldRange(left, op, right) + + raise eql.errors.EqlCompileError("Unable to convert {}".format(tree)) diff --git a/kql/errors.py b/kql/errors.py new file mode 100644 index 000000000..8530c8109 --- /dev/null +++ b/kql/errors.py @@ -0,0 +1,17 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +from eql import EqlError, EqlParseError, EqlCompileError + + +class KqlParseError(EqlParseError): + """EQL Parsing Error.""" + + +class KqlCompileError(EqlCompileError): + """Class for KQL-specific compile errors.""" + + +class KqlRuntimeError(EqlError): + """Error for failures within the KQL evaluator.""" diff --git a/kql/evaluator.py b/kql/evaluator.py new file mode 100644 index 000000000..47b74db8b --- /dev/null +++ b/kql/evaluator.py @@ -0,0 +1,151 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import operator +import re + +import eql.ast +from eql import Walker, EqlCompileError, utils +from eql.functions import CidrMatch +from .errors import KqlRuntimeError, KqlCompileError + + +class FilterGenerator(Walker): + __cidr_cache = {} + + def _walk_default(self, node, *args, **kwargs): + raise KqlCompileError("Unable to convert {}".format(node)) + + @classmethod + def equals(cls, term, value): + if utils.is_string(term) and utils.is_string(value): + if CidrMatch.ip_compiled.match(term) and CidrMatch.cidr_compiled.match(value): + # check for an ipv4 cidr + if value not in cls.__cidr_cache: + cls.__cidr_cache[value] = CidrMatch.get_callback(None, eql.ast.String(value)) + return cls.__cidr_cache[value](term) + + return term == value + + @classmethod + def get_terms(cls, document, path): + if isinstance(document, (tuple, list)): + for d in document: + yield from cls.get_terms(d, path) + + elif isinstance(document, dict): + document = document.get(path[0]) + path = path[1:] + + if len(path) > 0: + yield from cls.get_terms(document, path) + elif isinstance(document, (tuple, list)): + yield from iter(document) + elif document is not None: + yield document + + def _walk_value(self, tree, compare_function=None): + value = tree.value + compare_function = compare_function or self.equals + + def check_value(term): + if term is None: + return False + + if isinstance(term, list): + return any(check_value(t) for t in term) + + if isinstance(term, (bool, float, int)) or utils.is_string(term): + v = value + + if utils.is_string(v) and isinstance(term, (bool, int, float)): + if isinstance(v, bool): + v = v == "false" + if isinstance(term, int): + v = int(v) + elif isinstance(v, float): + v = float(v) + + elif utils.is_string(term) and isinstance(v, (bool, int, float)): + v = utils.to_unicode(v) + + return compare_function(term, v) + else: + raise KqlRuntimeError("Cannot compare value {}".format(term)) + + return check_value + + def _walk_exists(self, _): + return lambda terms: any(t is not None for t in terms) + + def _walk_wildcard(self, tree): + pattern = tree.value + regex = re.compile(".*?".join(map(re.escape, pattern.split("*"))), re.UNICODE | re.DOTALL) + return lambda terms: any(t is not None and regex.fullmatch(t) for t in terms) + + def _walk_field(self, field): + path = field.name.split(".") + get_terms = self.get_terms + + def callback(document): + terms = get_terms(document, path) + terms = list(terms) + return terms + + return callback + + def _walk_field_range(self, tree): + field = self.walk(tree.field) + operators = {"<": operator.lt, "<=": operator.le, ">=": operator.ge, ">": operator.gt} + + check_range = self.walk(tree.value, operators[tree.operator]) + return lambda doc: check_range(field(doc)) + + def _walk_nested_query(self, tree): + field = self.walk(tree.field) + expr = self.walk(tree.expr) + + def check_nested(doc): + doc = field(doc) + + if isinstance(doc, dict): + return expr(doc) + elif isinstance(doc, (list, tuple)): + return any(expr(d) for d in doc) + + return check_nested + + def _walk_list(self, trees, reduce_function, *args, **kwargs): + walked = [self.walk(item, *args, **kwargs) for item in trees.items] + return lambda x: reduce_function(item(x) for item in walked) + + def _walk_not_expr(self, tree): + expr = self.walk(tree.expr) + return lambda doc: not expr(doc) + + def _walk_and_expr(self, tree): + return self._walk_list(tree, all) + + def _walk_or_expr(self, tree): + return self._walk_list(tree, any) + + def _walk_and_values(self, tree): + return self._walk_list(tree, all) + + def _walk_or_values(self, tree): + return self._walk_list(tree, any) + + def _walk_not_value(self, tree): + expr = self.walk(tree.value) + return lambda value: not expr(value) + + def _walk_field_comparison(self, tree): + field = self.walk(tree.field) + value = self.walk(tree.value) + + return lambda doc: value(field(doc)) + + @classmethod + def filter(cls, expression): + return cls().walk(expression) diff --git a/kql/kql.g b/kql/kql.g new file mode 100644 index 000000000..8f2cec115 --- /dev/null +++ b/kql/kql.g @@ -0,0 +1,51 @@ +?query: or_query +?or_query: and_query (OR and_query)* +?and_query: not_query (AND not_query)* +?not_query: NOT? sub_query +?sub_query: "(" or_query ")" + | nested_query +?nested_query: field ":" "{" or_query "}" + | expression +?expression: field_range_expression + | field_value_expression + | value_expression + +field_range_expression: field RANGE_OPERATOR literal +field_value_expression: field ":" list_of_values +?value_expression: value + +?list_of_values: "(" or_list_of_values ")" + | value +?or_list_of_values: and_list_of_values (OR and_list_of_values)* +?and_list_of_values: not_list_of_values (AND not_list_of_values)* +?not_list_of_values: NOT? list_of_values + +field: literal + +value: QUOTED_STRING + | UNQUOTED_LITERAL + + +literal: QUOTED_STRING + | UNQUOTED_LITERAL + +RANGE_OPERATOR: "<=" + | ">=" + | "<" + | ">" + +UNQUOTED_LITERAL: UNQUOTED_CHAR+ +UNQUOTED_CHAR: "\\" /[trn]/ // escaped whitespace + | "\\" /[\\():<>"*{}]/ // escaped specials + | "\\" (AND | OR | NOT) // escaped keywords + | "*" // wildcard + | /[^\\():<>"*{} \t\r\n]/ // anything else + +QUOTED_STRING: /"(\\[tnr"\\]|[^\r\n"])*"/ + +OR.2: "or" | "OR" +AND.2: "and" | "AND" +NOT.2: "not" | "NOT" + +WHITESPACE: (" " | "\r" | "\n" | "\t" )+ +%ignore WHITESPACE \ No newline at end of file diff --git a/kql/kql2eql.py b/kql/kql2eql.py new file mode 100755 index 000000000..96efbfad8 --- /dev/null +++ b/kql/kql2eql.py @@ -0,0 +1,100 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +#!/usr/bin/env python + +import eql + +from .parser import BaseKqlParser + + +class KqlToEQL(BaseKqlParser): + + # + # Lark Visit methods + # + @staticmethod + def to_eql_field(name): + path = name.split(".") + return eql.ast.Field(path[0], path[1:]) + + def or_query(self, tree): + terms = [self.visit(t) for t in tree.child_trees] + return eql.ast.Or(terms) + + def and_query(self, tree): + terms = [self.visit(t) for t in tree.child_trees] + return eql.ast.And(terms) + + def not_query(self, tree): + return eql.ast.Not(self.visit(tree.children[-1])) + + def nested_query(self, tree): + raise self.error(tree, "Unable to convert nested query to EQL") + + def field_range_expression(self, tree): + field_tree, operator, literal_tree = tree.children + field_name = self.visit(field_tree) + + # check the field against the schema + self.get_field_type(field_name, field_tree) + + # get and convert the value + value = self.convert_value(field_name, self.visit(literal_tree), literal_tree) + literal = eql.ast.Literal.from_python(value) + + field = self.to_eql_field(field_name) + return eql.ast.Comparison(field, operator.value, literal) + + def field_value_expression(self, tree): + field_tree, value_tree = tree.child_trees + + with self.scope(self.visit(field_tree)) as field_name: + # check the field against the schema + self.get_field_type(field_name, field_tree) + return self.visit(value_tree) + + def or_list_of_values(self, tree): + children = [self.visit(t) for t in tree.child_trees] + return eql.ast.Or(children) + + def and_list_of_values(self, tree): + children = [self.visit(t) for t in tree.child_trees] + return eql.ast.And(children) + + def not_list_of_values(self, tree): + return eql.ast.Not(self.visit(tree.children[-1])) + + def field(self, tree): + literal = self.visit(tree.children[0]) + return eql.utils.to_unicode(literal) + + def value(self, tree): + # TODO: check the logic for kuery.peg + value = self.unescape_literal(tree.children[0]) + + if self.scoped_field is None: + raise self.error(tree, "Value not tied to field") + + field_name = self.scoped_field + field = self.to_eql_field(field_name) + value = self.convert_value(field_name, value, tree) + value_ast = eql.ast.Literal.from_python(value) + + if value is None: + return eql.ast.IsNull(field) + + if eql.utils.is_string(value) and value.replace("*", "") == "": + return eql.ast.IsNotNull(field) + + if eql.utils.is_string(value) and "*" in value: + return eql.ast.FunctionCall("wildcard", [field, value_ast]) + + if self.get_field_type(field_name) == "ip" and "/" in value: + return eql.ast.FunctionCall("cidrMatch", [field, value_ast]) + + return eql.ast.Comparison(field, "==", value_ast) + + def literal(self, tree): + return self.unescape_literal(tree.children[0]) diff --git a/kql/optimizer.py b/kql/optimizer.py new file mode 100644 index 000000000..0b71cc729 --- /dev/null +++ b/kql/optimizer.py @@ -0,0 +1,129 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import functools + +from eql import Walker, DepthFirstWalker + +from .ast import AndValues, NotValue, Value, OrValues, NotExpr, FieldComparison + + +class Optimizer(DepthFirstWalker): + + def flat_optimize(self, tree): + return Walker.walk(self, tree) + + def _walk_default(self, tree, *args, **kwargs): + return tree + + def group_fields(self, tree, value_cls): # type: (List, type) -> KqlNode + cls = type(tree) + field_groups = {} + ungrouped = [] + + for term in tree.items: + # move a `not` inwards before grouping + if isinstance(term, NotExpr) and isinstance(term.expr, FieldComparison): + term = FieldComparison(term.expr.field, NotValue(term.expr.value)) + + if isinstance(term, FieldComparison): + if term.field.name in field_groups: + existing_checks = field_groups[term.field.name] + existing_checks.append(term) + continue + else: + field_groups[term.field.name] = [term] + + ungrouped.append(term) + + for term in ungrouped: + if isinstance(term, FieldComparison): + term.value = self.flat_optimize(value_cls([t.value for t in field_groups[term.field.name]])) + + ungrouped = [self.flat_optimize(u) for u in ungrouped] + return cls(ungrouped) if len(ungrouped) > 1 else ungrouped[0] + + @staticmethod + def sort_key(a, b): + if isinstance(a, Value) and not isinstance(b, Value): + return -1 + if not isinstance(a, Value) and isinstance(b, Value): + return +1 + + if isinstance(a, Value) and isinstance(b, Value): + t_a = type(a) + t_b = type(b) + + if t_a == t_b: + return (a.value > b.value) - (a.value < b.value) + else: + return (t_a.__name__ > b.__name__) - (a.__name__ < b.__name__) + + else: + # unable to compare + return 0 + + def _walk_field_comparison(self, tree): # type: (FieldComparison) -> KqlNode + # if there's a single `not`, then pull it out of the expression + if isinstance(tree.value, NotValue): + return NotExpr(FieldComparison(tree.field, tree.value.value)) + return tree + + def flatten(self, tree): # type: (List) -> List + cls = type(tree) + flattened = [] + for node in tree.items: + if isinstance(node, cls): + flattened.extend(node.items) + else: + flattened.append(node) + + flattened = [self.flat_optimize(t) for t in flattened] + return cls(flattened) + + def flatten_values(self, tree, dual_cls): # type: (List, type) -> List + cls = type(tree) + flattened = [] + not_term = None + + for term in self.flatten(tree).items: + if isinstance(term, NotValue) and isinstance(term.value, Value): + # create a copy to leave the source tree unaltered + term = NotValue(term.value) + if not_term is None: + not_term = term + else: + not_term.value = dual_cls([not_term.value, term.value]) + continue + + flattened.append(term) + + if not_term is not None: + not_term.value = self.flat_optimize(not_term.value) + + flattened = [self.flat_optimize(t) for t in flattened] + flattened.sort(key=functools.cmp_to_key(self.sort_key)) + return cls(flattened) if len(flattened) > 1 else flattened[0] + + def _walk_not_value(self, tree): # type: (NotValue) -> KqlNode + if isinstance(tree.value, NotValue): + return tree.value.value + return tree + + def _walk_or_values(self, tree): + return self.flatten_values(tree, AndValues) + + def _walk_and_values(self, tree): + return self.flatten_values(tree, OrValues) + + def _walk_not_expr(self, tree): # type: (NotExpr) -> KqlNode + if isinstance(tree.expr, NotExpr): + return tree.expr.expr + return tree + + def _walk_and_expr(self, tree): # type: (AndExpr) -> KqlNode + return self.group_fields(self.flatten(tree), value_cls=AndValues) + + def _walk_or_expr(self, tree): # type: (OrExpr) -> KqlNode + return self.group_fields(self.flatten(tree), value_cls=OrValues) diff --git a/kql/parser.py b/kql/parser.py new file mode 100644 index 000000000..40f362423 --- /dev/null +++ b/kql/parser.py @@ -0,0 +1,285 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import contextlib +import os +import re + +import eql +from lark import Token # noqa: F401 +from lark import Tree, Lark +from lark.exceptions import LarkError, UnexpectedEOF +from lark.visitors import Interpreter + +from kql.errors import KqlParseError +from .ast import * # noqa: F403 + + +STRING_FIELDS = ("keyword", "text") + + +class KvTree(Tree): + + @property + def child_trees(self): + return [child for child in self.children if isinstance(child, KvTree)] + + @property + def child_tokens(self): + return [child for child in self.children if isinstance(child, Token)] + + +grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kql.g") +with open(grammar_file, "rt") as f: + grammar = f.read() + +lark_parser = Lark(grammar, propagate_positions=True, tree_class=KvTree, start=['query'], parser='lalr') + + +class BaseKqlParser(Interpreter): + NON_SPACE_WS = re.compile(r"[^\S ]+") + ip_regex = re.compile("^" + eql.functions.CidrMatch.ip_re + "(/([0-2]?[0-9]|3[0-2]))?$") + + unquoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n"} + + for special in "\\():<>\"*{}]": + unquoted_escapes["\\" + special] = special + + unquoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(unquoted_escapes)) + ")") + + quoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n", "\\\\": "\\", "\\\"": "\""} + quoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(quoted_escapes)) + ")") + + def __init__(self, text, schema=None): + self.text = text + self.lines = [t.rstrip("\r\n") for t in self.text.splitlines(True)] + self.scoped_field = None + self.schema = schema + + def assert_lower_token(self, *tokens): + for token in tokens: + if str(token) != str(token).lower(): + raise self.error(token, "Expected '{lower}' but got '{token}'".format(token=token, lower=str(token).lower())) + + def error(self, node, message, end=False, cls=KqlParseError, width=None, **kwargs): + """Generate.""" + if kwargs: + message = message.format(**kwargs) + + line_number = node.line - 1 + column = node.column - 1 + + # get more lines for more informative error messages. three before + two after + before = self.lines[:line_number + 1][-3:] + after = self.lines[line_number + 1:][:3] + + source = '\n'.join(b for b in before) + trailer = '\n'.join(a for a in after) + + # Determine if the error message can easily look like this + # ^^^^ + if width is None and not end and node.line == node.end_line: + if not self.NON_SPACE_WS.search(self.lines[line_number][column:node.end_column]): + width = node.end_column - node.column + + if width is None: + width = 1 + + return cls(message, line_number, column, source, width=width, trailer=trailer) + + def __default__(self, tree): + raise NotImplementedError("Unable to visit tree {} of type: {}".format(tree, tree.data)) + + def unescape_literal(self, token): # type: (Token) -> (int|float|str|bool) + if token.type == "QUOTED_STRING": + return self.convert_quoted_string(token.value) + else: + return self.convert_unquoted_literal(token.value) + + @contextlib.contextmanager + def scope(self, field): + # with self.scope(field) as field: + # ... + self.scoped_field = field + yield field + self.scoped_field = None + + def get_field_type(self, dotted_path, lark_tree=None): + if self.schema is not None: + if lark_tree is not None and dotted_path not in self.schema: + raise self.error(lark_tree, "Unknown field") + + return self.schema[dotted_path] + + @staticmethod + def get_literal_type(literal_value): + if isinstance(literal_value, bool): + return "boolean" + elif isinstance(literal_value, float): + return "float" + elif isinstance(literal_value, int): + return "long" + elif eql.utils.is_string(literal_value): + # this will be converted when compared to the field + return "keyword" + elif literal_value is None: + return "null" + else: + raise NotImplementedError("Unknown literal type: {}".format(type(literal_value).__name__)) + + def convert_value(self, field_name, python_value, value_tree): + field_type = self.get_field_type(field_name) + value_type = self.get_literal_type(python_value) + + if field_type is not None and field_type != value_type: + if field_type in STRING_FIELDS: + return eql.utils.to_unicode(python_value) + elif field_type in ("float", "long"): + try: + return float(python_value) if field_type == "float" else int(python_value) + except ValueError: + pass + elif field_type == "ip" and value_type == "keyword": + if "::" in python_value or self.ip_regex.match(python_value) is not None: + return python_value + + raise self.error(value_tree, "Value doesn't match {field}'s type: {type}", + field=field_name, type=field_type) + + # otherwise, there's nothing to convert + return python_value + + @classmethod + def convert_unquoted_literal(cls, text): + if text == "true": + return True + elif text == "false": + return False + elif text == "null": + return None + else: + for numeric in (int, float): + try: + return numeric(text) + except ValueError: + pass + + text = cls.unquoted_regex.sub(lambda r: cls.unquoted_escapes[r.group()], text) + return text + + @classmethod + def convert_quoted_string(cls, text): + inner_text = text[1:-1] + unescaped = cls.quoted_regex.sub(lambda r: cls.quoted_escapes[r.group()], inner_text) + return unescaped + + +class KqlParser(BaseKqlParser): + def or_query(self, tree): + self.assert_lower_token(*tree.child_tokens) + terms = [self.visit(t) for t in tree.child_trees] + return OrExpr(terms) + + def and_query(self, tree): + self.assert_lower_token(*tree.child_tokens) + terms = [self.visit(t) for t in tree.child_trees] + return AndExpr(terms) + + def not_query(self, tree): + self.assert_lower_token(*tree.child_tokens) + return NotExpr(self.visit(tree.children[-1])) + + @contextlib.contextmanager + def nest(self, lark_tree): + schema = self.schema + dotted_path = self.visit(lark_tree) + + if self.get_field_type(dotted_path, lark_tree) != "nested": + raise self.error(lark_tree, "Expected a nested field") + + try: + self.schema = self.schema[dotted_path] + yield + finally: + self.schema = schema + + def nested_query(self, tree): + # field_tree, query_tree = tree.child_trees + # + # with self.nest(field_tree) as field: + # return NestedQuery(field, self.visit(query_tree)) + + raise self.error(tree, "Nested queries are not yet supported") + + def field_value_expression(self, tree): + field_tree, expr = tree.child_trees + + with self.scope(self.visit(field_tree)) as field: + # check the field against the schema + self.get_field_type(field.name, field_tree) + return FieldComparison(field, self.visit(expr)) + + def field_range_expression(self, tree): + field_tree, operator, literal = tree.children + with self.scope(self.visit(field_tree)) as field: + value = self.convert_value(field.name, self.visit(literal), literal) + return FieldRange(field, operator, Value.from_python(value)) + + def or_list_of_values(self, tree): + self.assert_lower_token(*tree.child_tokens) + return OrValues([self.visit(t) for t in tree.child_trees]) + + def and_list_of_values(self, tree): + self.assert_lower_token(*tree.child_tokens) + return AndValues([self.visit(t) for t in tree.child_trees]) + + def not_list_of_values(self, tree): + self.assert_lower_token(*tree.child_tokens) + return NotValue(self.visit(tree.children[-1])) + + def literal(self, tree): + return self.unescape_literal(tree.children[0]) + + def field(self, tree): + literal = self.visit(tree.children[0]) + return Field(eql.utils.to_unicode(literal)) + + def value(self, tree): + if self.scoped_field is None: + raise self.error(tree, "Value not tied to field") + + field_name = self.scoped_field.name + token = tree.children[0] + value = self.unescape_literal(token) + + if token.type == "UNQUOTED_LITERAL" and "*" in token.value: + field_type = self.get_field_type(field_name) + if len(value.replace("*", "")) == 0: + return Exists() + + if field_type is not None and field_type not in ("keyword", "wildcard"): + raise self.error(tree, "Unable to perform wildcard on field {field} of {type}", + field=field_name, type=field_type) + + return Wildcard(token.value) + + # try to convert the value to the appropriate type + # example: 1 -> "1" if the field is actually keyword + value = self.convert_value(field_name, value, tree) + return Value.from_python(value) + + +def lark_parse(text): + if not text.strip(): + raise KqlParseError("No query provided", 0, 0, "") + + walker = BaseKqlParser(text) + + try: + return lark_parser.parse(text) + except UnexpectedEOF: + raise KqlParseError("Unexpected EOF", len(walker.lines), len(walker.lines[-1].strip()), walker.lines[-1]) + except LarkError as exc: + raise KqlParseError("Invalid syntax", exc.line - 1, exc.column - 1, + '\n'.join(walker.lines[exc.line - 2:exc.line]))