From 5f867dbb72860a0ce3521beefb56f135f98f0300 Mon Sep 17 00:00:00 2001 From: Ross Wolf <31489089+rw-access@users.noreply.github.com> Date: Wed, 22 Jul 2020 13:05:45 -0400 Subject: [PATCH] Add KQL -> DSL conversion (#81) * Add KQL -> DSL converter * Lint with black to 120 chars * Add more tests and flatten shoulds * Fix NotValue conversion to DSL --- kql/__init__.py | 25 ++++++- kql/dsl.py | 117 ++++++++++++++++++++++++++++++ kql/evaluator.py | 9 ++- kql/kql2eql.py | 2 - tests/kuery/test_dsl.py | 156 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 301 insertions(+), 8 deletions(-) create mode 100644 kql/dsl.py create mode 100644 tests/kuery/test_dsl.py diff --git a/kql/__init__.py b/kql/__init__.py index b3af0e336..9469b172c 100644 --- a/kql/__init__.py +++ b/kql/__init__.py @@ -5,6 +5,7 @@ import eql from . import ast +from .dsl import ToDsl from .eql2kql import Eql2Kql from .errors import KqlParseError, KqlCompileError from .evaluator import FilterGenerator @@ -14,17 +15,29 @@ from .parser import lark_parse, KqlParser __version__ = '0.1.4' __all__ = ( "ast", - "to_eql", - "lint", - "parse", "from_eql", "get_evaluator", "KqlParseError", "KqlCompileError", + "lint", + "parse", + "to_dsl", + "to_eql", ) +def to_dsl(parsed, optimize=True, schema=None): + """Convert KQL to Elasticsearch Query DSL.""" + if not isinstance(parsed, ast.KqlNode): + parsed = parse(parsed, optimize, schema) + + return ToDsl.convert(parsed) + + def to_eql(text, optimize=True, schema=None): + if isinstance(text, bytes): + text = text.decode("utf-8") + lark_parsed = lark_parse(text) converted = KqlToEQL(text, schema=schema).visit(lark_parsed) @@ -32,6 +45,9 @@ def to_eql(text, optimize=True, schema=None): def parse(text, optimize=True, schema=None): + if isinstance(text, bytes): + text = text.decode("utf-8") + lark_parsed = lark_parse(text) converted = KqlParser(text, schema=schema).visit(lark_parsed) @@ -39,6 +55,9 @@ def parse(text, optimize=True, schema=None): def lint(text): + if isinstance(text, bytes): + text = text.decode("utf-8") + return parse(text, optimize=True).render() diff --git a/kql/dsl.py b/kql/dsl.py new file mode 100644 index 000000000..2edc48adb --- /dev/null +++ b/kql/dsl.py @@ -0,0 +1,117 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +from collections import defaultdict +from eql import Walker +from .errors import KqlCompileError + + +def boolean(**kwargs): + """Wrap a query in a boolean term and optimize while building.""" + assert len(kwargs) == 1 + [(boolean_type, children)] = kwargs.items() + + if not isinstance(children, list): + children = [children] + + dsl = defaultdict(list) + + if boolean_type in ("must", "filter"): + # safe to convert and(and(x), y) -> and(x, y) + for child in children: + if list(child) == ["bool"]: + for child_type, child_terms in child["bool"].items(): + if child_type in ("must", "filter"): + dsl[child_type].extend(child_terms) + elif child_type == "should": + if "should" not in dsl: + dsl[child_type].extend(child_terms) + else: + dsl[boolean_type].append(boolean(should=child_terms)) + elif child_type == "must_not": + dsl[child_type].extend(child_terms) + elif child_type != "minimum_should_match": + raise ValueError("Unknown term {}: {}".format(child_type, child_terms)) + else: + dsl[boolean_type].append(child) + + elif boolean_type == "should": + # can flatten `should` of `should` + for child in children: + if list(child) == ["bool"] and set(child["bool"]).issubset({"should", "minimum_should_match"}): + dsl["should"].extend(child["bool"]["should"]) + else: + dsl[boolean_type].append(child) + + elif boolean_type == "must_not" and len(children) == 1: + # must_not: [{bool: {must: x}}] -> {must_not: x} + child = children[0] + if list(child) == ["bool"] and list(child["bool"]) in (["filter"], ["must"]): + negated, = child["bool"].values() + dsl = {"must_not": negated} + else: + dsl = {"must_not": children} + + else: + dsl = dict(kwargs) + + if "should" in dsl: + dsl.update(minimum_should_match=1) + + dsl = {"bool": dict(dsl)} + return dsl + + +class ToDsl(Walker): + + def _walk_default(self, node, *args, **kwargs): + raise KqlCompileError("Unable to convert {}".format(node)) + + def _walk_exists(self, _): + return lambda field: {"exists": {"field": field}} + + def _walk_wildcard(self, tree): + return lambda field: {"query_string": {"fields": [field], "query": tree.value}} + + def _walk_value(self, tree): + return lambda field: {"match": {field: tree.value}} + + def _walk_field(self, field): + return field.name + + def _walk_field_range(self, tree): + operator_map = {"<": "lt", "<=": "lte", ">=": "gte", ">": "gt"} + field = self.walk(tree.field) + return {"range": {field: {operator_map[tree.operator]: tree.value.value}}} + + def _walk_not_expr(self, tree): + return boolean(must_not=[self.walk(tree.expr)]) + + def _walk_and_expr(self, tree): + return boolean(filter=[self.walk(node) for node in tree.items]) + + def _walk_or_expr(self, tree): + return boolean(should=[self.walk(node) for node in tree.items]) + + def _walk_and_values(self, tree): + children = [self.walk(node) for node in tree.items] + return lambda field: boolean(filter=[child(field) for child in children]) + + def _walk_or_values(self, tree): + children = [self.walk(node) for node in tree.items] + return lambda field: boolean(should=[child(field) for child in children]) + + def _walk_not_value(self, tree): + child = self.walk(tree.value) + return lambda field: boolean(must_not=[child(field)]) + + def _walk_field_comparison(self, tree): + field = self.walk(tree.field) + value_fn = self.walk(tree.value) + + return value_fn(field) + + @classmethod + def convert(cls, tree): + return boolean(filter=[cls().walk(tree)]) diff --git a/kql/evaluator.py b/kql/evaluator.py index 47b74db8b..649c565f4 100644 --- a/kql/evaluator.py +++ b/kql/evaluator.py @@ -32,16 +32,19 @@ class FilterGenerator(Walker): def get_terms(cls, document, path): if isinstance(document, (tuple, list)): for d in document: - yield from cls.get_terms(d, path) + for term in cls.get_terms(d, path): + yield term elif isinstance(document, dict): document = document.get(path[0]) path = path[1:] if len(path) > 0: - yield from cls.get_terms(document, path) + for term in cls.get_terms(document, path): + yield term elif isinstance(document, (tuple, list)): - yield from iter(document) + for term in document: + yield term elif document is not None: yield document diff --git a/kql/kql2eql.py b/kql/kql2eql.py index 96efbfad8..0bab5d741 100755 --- a/kql/kql2eql.py +++ b/kql/kql2eql.py @@ -2,8 +2,6 @@ # or more contributor license agreements. Licensed under the Elastic License; # you may not use this file except in compliance with the Elastic License. -#!/usr/bin/env python - import eql from .parser import BaseKqlParser diff --git a/tests/kuery/test_dsl.py b/tests/kuery/test_dsl.py new file mode 100644 index 000000000..7a7a4851c --- /dev/null +++ b/tests/kuery/test_dsl.py @@ -0,0 +1,156 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import unittest +import kql + + +class TestKQLtoDSL(unittest.TestCase): + def validate(self, kql_source, dsl, **kwargs): + actual_dsl = kql.to_dsl(kql_source, **kwargs) + self.assertListEqual(list(actual_dsl), ["bool"]) + self.assertDictEqual(actual_dsl["bool"], dsl) + + def test_field_match(self): + def match(**kv): + return {"filter": [{"match": kv}]} + + self.validate("user:bob", match(user="bob")) + self.validate("number:-1", match(number=-1)) + self.validate("number:1.1", match(number=1.1)) + self.validate("boolean:true", match(boolean=True)) + self.validate("boolean:false", match(boolean=False)) + + def test_field_exists(self): + self.validate("user:*", {"filter": [{"exists": {"field": "user"}}]}) + + def test_field_inequality(self): + def rng(op, val): + return {"filter": [{"range": {"field": {op: val}}}]} + + self.validate("field < value", rng("lt", "value")) + self.validate("field > -1", rng("gt", -1)) + self.validate("field <= 1.1", rng("lte", 1.1)) + self.validate("field >= 0", rng("gte", 0)) + self.validate("field >= abc", rng("gte", "abc")) + + def test_or_query(self): + self.validate( + "field:value or field2:value2", + {"should": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}], "minimum_should_match": 1}, + ) + + def test_and_query(self): + self.validate( + "field:value and field2:value2", + {"filter": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, + ) + + def test_not_query(self): + self.validate("not field:value", {"must_not": [{"match": {"field": "value"}}]}) + self.validate("field:(not value)", {"must_not": [{"match": {"field": "value"}}]}) + self.validate("field:(a and not b)", { + "filter": [{"match": {"field": "a"}}], + "must_not": [{"match": {"field": "b"}}] + }) + self.validate( + "not field:value and not field2:value2", + {"must_not": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, + ) + self.validate( + "not (field:value or field2:value2)", + { + "must_not": [ + { + "bool": { + "minimum_should_match": 1, + "should": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}], + } + } + ] + }, + optimize=False, + ) + + self.validate("not (field:value and field2:value2)", + { + "must_not": [ + {"match": {"field": "value"}}, + {"match": {"field2": "value2"}} + ] + }) + + def test_optimizations(self): + self.validate( + "(field:value or field2:value2) and field3:value3", + { + "should": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}], + "filter": [{"match": {"field3": "value3"}}], + "minimum_should_match": 1, + }, + ) + + self.validate( + "(field:value and field2:value2) or field3:value3", + { + "should": [ + {"bool": {"filter": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}}, + {"match": {"field3": "value3"}}, + ], + "minimum_should_match": 1, + }, + ) + + self.validate( + "a:(v1 or v2 or v3) or b:(v4 or v5)", + { + "should": [ + {"match": {"a": "v1"}}, + {"match": {"a": "v2"}}, + {"match": {"a": "v3"}}, + {"match": {"b": "v4"}}, + {"match": {"b": "v5"}}, + ], + "minimum_should_match": 1, + }, + ) + + self.validate( + "a:(v1 or v2 or v3) and b:(v4 or v5)", + { + "should": [ + {"match": {"a": "v1"}}, + {"match": {"a": "v2"}}, + {"match": {"a": "v3"}} + ], + "filter": [ + { + "bool": { + "should": [ + {"match": {"b": "v4"}}, + {"match": {"b": "v5"}} + ], + "minimum_should_match": 1 + } + } + ], + "minimum_should_match": 1, + }, + ) + + self.validate( + "(field:value and not field2:value2) or field3:value3", + { + "should": [ + { + "bool": { + "filter": [{"match": {"field": "value"}}], + "must_not": [{"match": {"field2": "value2"}}], + } + }, + {"match": {"field3": "value3"}}, + ], + "minimum_should_match": 1, + }, + )