# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
# or more contributor license agreements. Licensed under the Elastic License
|
|
# 2.0; you may not use this file except in compliance with the Elastic License
|
|
# 2.0.
|
|
|
|
import contextlib
|
|
import os
|
|
import re
|
|
|
|
import eql
|
|
from lark import Token # noqa: F401
|
|
from lark import Tree, Lark
|
|
from lark.exceptions import LarkError, UnexpectedEOF
|
|
from lark.visitors import Interpreter
|
|
|
|
from kql.errors import KqlParseError
|
|
from .ast import * # noqa: F403
|
|
|
|
|
|
# Elasticsearch field types whose values are text; literals compared against
# fields of these types are coerced to unicode strings (see convert_value).
STRING_FIELDS = ("keyword", "text")
class KvTree(Tree):
    """Lark parse tree with helpers for splitting children into subtrees and tokens."""

    @property
    def child_trees(self):
        """Return only the children that are themselves ``KvTree`` nodes."""
        return [node for node in self.children if isinstance(node, KvTree)]

    @property
    def child_tokens(self):
        """Return only the children that are lexer ``Token`` objects."""
        return [node for node in self.children if isinstance(node, Token)]
# Load the KQL grammar file that ships alongside this module and build one
# shared LALR parser. propagate_positions=True keeps line/column information
# on tree nodes so parse errors can point back into the query text.
grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kql.g")

# Read the grammar as UTF-8 explicitly so parsing does not depend on the
# locale's preferred encoding.
with open(grammar_file, "rt", encoding="utf-8") as f:
    grammar = f.read()

lark_parser = Lark(grammar, propagate_positions=True, tree_class=KvTree, start=['query'], parser='lalr')
class BaseKqlParser(Interpreter):
    """Base lark Interpreter for KQL parse trees.

    Holds the raw query text, an optional field->type mapping schema, and the
    shared literal-unescaping and type-coercion helpers used by subclasses.
    """

    # Any whitespace other than a plain space (tab, CR, LF, ...). Used to
    # decide whether an error underline fits cleanly on a single line.
    NON_SPACE_WS = re.compile(r"[^\S ]+")
    # IP address pattern from EQL, with an optional /0-32 CIDR suffix.
    ip_regex = re.compile("^" + eql.functions.CidrMatch.ip_re + "(/([0-2]?[0-9]|3[0-2]))?$")

    # Backslash escapes recognized inside unquoted literals.
    unquoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n"}

    # KQL's special characters may also be backslash-escaped in unquoted text.
    for special in "\\():<>\"*{}]":
        unquoted_escapes["\\" + special] = special

    unquoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(unquoted_escapes)) + ")")

    # Backslash escapes recognized inside quoted strings.
    quoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n", "\\\\": "\\", "\\\"": "\""}
    quoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(quoted_escapes)) + ")")

    def __init__(self, text, schema=None):
        """Initialize the parser over ``text``.

        :param text: the raw KQL query string.
        :param schema: optional dict mapping dotted field names to type names;
            field names may contain "*" wildcards.
        """
        self.text = text
        self.lines = [t.rstrip("\r\n") for t in self.text.splitlines(True)]
        self.scoped_field = None
        self.mapping_schema = schema
        self.star_fields = []

        if schema:
            # pre-compile a regex for every wildcarded field name in the schema
            for field, field_type in schema.items():
                if "*" in field:
                    parts = field.split("*")
                    pattern = re.compile("^{regex}$".format(regex=".*?".join(re.escape(w) for w in parts)))
                    self.star_fields.append(pattern)

    def assert_lower_token(self, *tokens):
        """Raise a parse error for any keyword token that is not lowercase (e.g. ``OR``)."""
        for token in tokens:
            if str(token) != str(token).lower():
                raise self.error(token, "Expected '{lower}' but got '{token}'".format(token=token, lower=str(token).lower()))

    def error(self, node, message, end=False, cls=KqlParseError, width=None, **kwargs):
        """Generate an error exception but don't raise it.

        :param node: lark tree or token carrying line/column attributes.
        :param message: error message, optionally with ``str.format`` placeholders.
        :param end: when True, skip the single-line underline-width calculation.
        :param cls: exception class to instantiate.
        :param width: explicit underline width; computed when possible if None.
        """
        if kwargs:
            message = message.format(**kwargs)

        line_number = node.line - 1
        column = node.column - 1

        # get surrounding lines for more informative error messages:
        # up to two lines before (plus the error line itself) and three after
        before = self.lines[:line_number + 1][-3:]
        after = self.lines[line_number + 1:][:3]

        source = '\n'.join(b for b in before)
        trailer = '\n'.join(a for a in after)

        # Determine if the error message can easily look like this
        #         ^^^^
        if width is None and not end and node.line == node.end_line:
            # only underline spans free of tabs/newlines that would throw off
            # the column math
            if not self.NON_SPACE_WS.search(self.lines[line_number][column:node.end_column]):
                width = node.end_column - node.column

        if width is None:
            width = 1

        return cls(message, line_number, column, source, width=width, trailer=trailer)

    def __default__(self, tree):
        """Fail loudly for any grammar rule without an explicit visitor method."""
        raise NotImplementedError("Unable to visit tree {} of type: {}".format(tree, tree.data))

    def unescape_literal(self, token):  # type: (Token) -> (int|float|str|bool)
        """Unescape a literal token into its Python value."""
        if token.type == "QUOTED_STRING":
            return self.convert_quoted_string(token.value)
        else:
            return self.convert_unquoted_literal(token.value)

    @contextlib.contextmanager
    def scope(self, field):
        """Temporarily mark ``field`` as the field that values are compared to.

        Usage::

            with self.scope(field) as field:
                ...
        """
        self.scoped_field = field
        try:
            yield field
        finally:
            # Always clear the scope, even when the body raises a parse error,
            # so a failed parse can't leak a stale field into later visits.
            self.scoped_field = None

    def get_field_type(self, dotted_path, lark_tree=None):
        """Look up a field's type in the schema.

        Raises "Unknown field" when ``lark_tree`` is given and the field
        matches neither an exact schema name nor a wildcard pattern.
        Returns None when no schema is configured.
        """
        matches_pattern = any(regex.match(dotted_path) for regex in self.star_fields)

        if self.mapping_schema is not None:
            if lark_tree is not None and dotted_path not in self.mapping_schema and not matches_pattern:
                raise self.error(lark_tree, "Unknown field")

            return self.mapping_schema.get(dotted_path)

    @staticmethod
    def get_literal_type(literal_value):
        """Map a Python literal to the equivalent Elasticsearch type name."""
        # bool must be checked before int: isinstance(True, int) is True
        if isinstance(literal_value, bool):
            return "boolean"
        elif isinstance(literal_value, float):
            return "float"
        elif isinstance(literal_value, int):
            return "long"
        elif eql.utils.is_string(literal_value):
            # this will be converted when compared to the field
            return "keyword"
        elif literal_value is None:
            return "null"
        else:
            raise NotImplementedError("Unknown literal type: {}".format(type(literal_value).__name__))

    def convert_value(self, field_name, python_value, value_tree):
        """Coerce ``python_value`` toward the schema type of ``field_name``.

        Raises a parse error (pointing at ``value_tree``) when the value
        cannot be made to match the field's type.
        """
        field_type = self.get_field_type(field_name)
        value_type = self.get_literal_type(python_value)

        if field_type is not None and field_type != value_type:
            if field_type in STRING_FIELDS:
                return eql.utils.to_unicode(python_value)
            elif field_type in ("float", "long"):
                try:
                    return float(python_value) if field_type == "float" else int(python_value)
                except ValueError:
                    pass
            elif field_type == "ip" and value_type == "keyword":
                # "::" loosely detects IPv6; otherwise require an IPv4/CIDR match
                if "::" in python_value or self.ip_regex.match(python_value) is not None:
                    return python_value

            raise self.error(value_tree, "Value doesn't match {field}'s type: {type}",
                             field=field_name, type=field_type)

        # otherwise, there's nothing to convert
        return python_value

    @classmethod
    def convert_unquoted_literal(cls, text):
        """Convert unquoted literal text to bool/None/int/float, falling back
        to an unescaped string."""
        if text == "true":
            return True
        elif text == "false":
            return False
        elif text == "null":
            return None
        else:
            for numeric in (int, float):
                try:
                    return numeric(text)
                except ValueError:
                    pass

        text = cls.unquoted_regex.sub(lambda r: cls.unquoted_escapes[r.group()], text)
        return text

    @classmethod
    def convert_quoted_string(cls, text):
        """Strip the surrounding quotes and resolve backslash escapes."""
        inner_text = text[1:-1]
        unescaped = cls.quoted_regex.sub(lambda r: cls.quoted_escapes[r.group()], inner_text)
        return unescaped
class KqlParser(BaseKqlParser):
    """Interpreter that converts a lark parse tree into KQL AST nodes."""

    def or_query(self, tree):
        """Combine subqueries joined by `or` into an OrExpr."""
        self.assert_lower_token(*tree.child_tokens)
        return OrExpr([self.visit(child) for child in tree.child_trees])

    def and_query(self, tree):
        """Combine subqueries joined by `and` into an AndExpr."""
        self.assert_lower_token(*tree.child_tokens)
        return AndExpr([self.visit(child) for child in tree.child_trees])

    def not_query(self, tree):
        """Wrap a negated subquery in a NotExpr."""
        self.assert_lower_token(*tree.child_tokens)
        negated = tree.children[-1]
        return NotExpr(self.visit(negated))

    @contextlib.contextmanager
    def nest(self, lark_tree):
        """Enter a nested field's sub-schema for the duration of the block."""
        outer_schema = self.mapping_schema
        dotted_path = self.visit(lark_tree)

        if self.get_field_type(dotted_path, lark_tree) != "nested":
            raise self.error(lark_tree, "Expected a nested field")

        try:
            self.mapping_schema = outer_schema[dotted_path]
            yield
        finally:
            self.mapping_schema = outer_schema

    def nested_query(self, tree):
        # Intended implementation (once supported): visit the field subtree
        # under self.nest(...) and wrap the inner query in a NestedQuery.
        raise self.error(tree, "Nested queries are not yet supported")

    def field_value_expression(self, tree):
        """Visit a `field: <expression>` comparison."""
        field_tree, expr_tree = tree.child_trees

        with self.scope(self.visit(field_tree)) as field:
            # validate the field name against the schema before visiting values
            self.get_field_type(field.name, field_tree)
            return FieldComparison(field, self.visit(expr_tree))

    def field_range_expression(self, tree):
        """Visit a `field <op> value` range comparison (e.g. `count > 5`)."""
        field_tree, op_token, literal_tree = tree.children

        with self.scope(self.visit(field_tree)) as field:
            converted = self.convert_value(field.name, self.visit(literal_tree), literal_tree)
            return FieldRange(field, op_token, Value.from_python(converted))

    def or_list_of_values(self, tree):
        """Combine values joined by `or` into an OrValues node."""
        self.assert_lower_token(*tree.child_tokens)
        return OrValues([self.visit(child) for child in tree.child_trees])

    def and_list_of_values(self, tree):
        """Combine values joined by `and` into an AndValues node."""
        self.assert_lower_token(*tree.child_tokens)
        return AndValues([self.visit(child) for child in tree.child_trees])

    def not_list_of_values(self, tree):
        """Wrap a negated value in a NotValue node."""
        self.assert_lower_token(*tree.child_tokens)
        negated = tree.children[-1]
        return NotValue(self.visit(negated))

    def literal(self, tree):
        """Unescape a literal token into its Python value."""
        first_child = tree.children[0]
        return self.unescape_literal(first_child)

    def field(self, tree):
        """Build a Field node from a literal field name."""
        name = self.visit(tree.children[0])
        return Field(eql.utils.to_unicode(name))

    def value(self, tree):
        """Convert a value token into an AST node for the currently scoped field."""
        if self.scoped_field is None:
            raise self.error(tree, "Value not tied to field")

        field_name = self.scoped_field.name
        token = tree.children[0]
        unescaped = self.unescape_literal(token)

        # unquoted values containing "*" are wildcard (or existence) checks
        if token.type == "UNQUOTED_LITERAL" and "*" in token.value:
            field_type = self.get_field_type(field_name)

            # a value made of nothing but wildcards just tests field existence
            if not unescaped.replace("*", ""):
                return Exists()

            if field_type is not None and field_type not in ("keyword", "wildcard"):
                raise self.error(tree, "Unable to perform wildcard on field {field} of {type}",
                                 field=field_name, type=field_type)

            return Wildcard(token.value)

        # try to convert the value to the appropriate type
        # example: 1 -> "1" if the field is actually keyword
        coerced = self.convert_value(field_name, unescaped, tree)
        return Value.from_python(coerced)
def lark_parse(text):
    """Parse KQL ``text`` with lark, translating lark failures to KqlParseError."""
    if not text.strip():
        raise KqlParseError("No query provided", 0, 0, "")

    # the walker is only used here for its normalized view of the input lines
    walker = BaseKqlParser(text)

    try:
        return lark_parser.parse(text)
    except UnexpectedEOF:
        last_line = walker.lines[-1]
        raise KqlParseError("Unexpected EOF", len(walker.lines), len(last_line.strip()), last_line)
    except LarkError as exc:
        # include the offending line plus the one before it for context
        context = '\n'.join(walker.lines[exc.line - 2:exc.line])
        raise KqlParseError("Invalid syntax", exc.line - 1, exc.column - 1, context)