[FR] Independently package kql / kibana and bump to py3.12 (#3514)

This commit is contained in:
Mika Ayenson
2024-03-14 20:18:32 -05:00
committed by GitHub
parent 3d2a36be32
commit d26981f712
31 changed files with 159 additions and 53 deletions
+16
View File
@@ -0,0 +1,16 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Wrapper around Kibana APIs for the Security Application."""
from .connector import Kibana
from .resources import RuleResource, Signal
# package version for the independently-packaged kibana lib
__version__ = '0.1.0'

# public API re-exported from the submodules above
__all__ = (
    "Kibana",
    "RuleResource",
    "Signal"
)
+230
View File
@@ -0,0 +1,230 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Wrapper around requests.Session for HTTP requests to Kibana."""
import atexit
import base64
import json
import sys
import threading
import uuid
import requests
from elasticsearch import Elasticsearch
_context = threading.local()
class Kibana(object):
    """Wrapper around the Kibana SIEM APIs."""

    CACHED = False  # NOTE(review): not read anywhere in this class; presumably toggled by callers — confirm

    def __init__(self, cloud_id=None, kibana_url=None, verify=True, elasticsearch=None, space=None):
        """Open a session to the platform."""
        self.authenticated = False
        self.session = requests.Session()
        self.session.verify = verify
        self.verify = verify

        self.cloud_id = cloud_id
        self.kibana_url = kibana_url.rstrip('/') if kibana_url else None
        self.elastic_url = None
        # the default space needs no `s/<space>/` URL prefix, so treat it as unset
        self.space = space if space and space.lower() != 'default' else None
        self.status = None
        self.provider_name = None
        self.provider_type = None

        if self.cloud_id:
            # cloud_id format: "<cluster_name>:<base64(domain$es_uuid$kibana_uuid)>"
            self.cluster_name, cloud_info = self.cloud_id.split(":")
            self.domain, self.es_uuid, self.kibana_uuid = \
                base64.b64decode(cloud_info.encode("utf-8")).decode("utf-8").split("$")

            if self.domain.endswith(':443'):
                self.domain = self.domain[:-4]

            # reject a kibana_url that contradicts the one derived from cloud_id
            kibana_url_from_cloud = f"https://{self.kibana_uuid}.{self.domain}:9243"
            if self.kibana_url and self.kibana_url != kibana_url_from_cloud:
                raise ValueError(f'kibana_url provided ({self.kibana_url}) does not match url derived from cloud_id '
                                 f'{kibana_url_from_cloud}')
            self.kibana_url = kibana_url_from_cloud

            self.elastic_url = f"https://{self.es_uuid}.{self.domain}:9243"

            self.provider_name = 'cloud-basic'
            self.provider_type = 'basic'

        # kbn-xsrf header is required by Kibana on state-changing requests
        self.session.headers.update({'Content-Type': "application/json", "kbn-xsrf": str(uuid.uuid4())})
        self.elasticsearch = elasticsearch

        if not verify:
            # silence the noisy per-request warning when TLS verification is off
            from requests.packages.urllib3.exceptions import \
                InsecureRequestWarning
            requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

        # log out automatically when the interpreter exits
        atexit.register(self.__close)

    @property
    def version(self):
        """Get the semantic version."""
        if self.status:
            return self.status.get("version", {}).get("number")

    def url(self, uri):
        """Get the full URL given a URI."""
        assert self.kibana_url is not None

        # If a space is defined update the URL accordingly
        uri = uri.lstrip('/')
        if self.space:
            uri = "s/{}/{}".format(self.space, uri)

        return f"{self.kibana_url}/{uri}"

    def request(self, method, uri, params=None, data=None, error=True, verbose=True, raw=False, **kwargs):
        """Perform a RESTful HTTP request with JSON responses.

        :param error: raise on non-2xx responses (after printing the body when verbose)
        :param raw: return the raw response bytes instead of parsed JSON
        """
        params = params or {}
        url = self.url(uri)
        params = {k: v for k, v in params.items()}
        body = None
        if data is not None:
            body = json.dumps(data)

        response = self.session.request(method, url, params=params, data=body, **kwargs)

        if response.status_code != 200:
            # retry once
            response = self.session.request(method, url, params=params, data=body, **kwargs)

        if error:
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError:
                if verbose:
                    print(response.content.decode("utf-8"), file=sys.stderr)
                raise

        # empty bodies (e.g. 204s) yield None rather than a JSON decode error
        if not response.content:
            return

        return response.content if raw else response.json()

    def get(self, uri, params=None, data=None, error=True, **kwargs):
        """Perform an HTTP GET."""
        return self.request('GET', uri, data=data, params=params, error=error, **kwargs)

    def put(self, uri, params=None, data=None, error=True, **kwargs):
        """Perform an HTTP PUT."""
        return self.request('PUT', uri, params=params, data=data, error=error, **kwargs)

    def post(self, uri, params=None, data=None, error=True, **kwargs):
        """Perform an HTTP POST."""
        return self.request('POST', uri, params=params, data=data, error=error, **kwargs)

    def patch(self, uri, params=None, data=None, error=True, **kwargs):
        """Perform an HTTP PATCH."""
        return self.request('PATCH', uri, params=params, data=data, error=error, **kwargs)

    def delete(self, uri, params=None, error=True, **kwargs):
        """Perform an HTTP DELETE."""
        return self.request('DELETE', uri, params=params, error=error, **kwargs)

    def login(self, kibana_username, kibana_password, provider_type=None, provider_name=None):
        """Authenticate to Kibana using the API to update our cookies."""
        payload = {'username': kibana_username, 'password': kibana_password}
        path = '/internal/security/login'

        try:
            self.post(path, data=payload, error=True, verbose=False)
        except requests.HTTPError as e:
            # 7.10 changed the structure of the auth data
            # providers dictated by Kibana configs in:
            # https://www.elastic.co/guide/en/kibana/current/security-settings-kb.html#authentication-security-settings
            # more details: https://discuss.elastic.co/t/kibana-7-10-login-issues/255201/2
            if e.response.status_code == 400 and '[undefined]' in e.response.text:
                provider_type = provider_type or self.provider_type or 'basic'
                provider_name = provider_name or self.provider_name or 'basic'

                payload = {
                    'params': payload,
                    'currentURL': '',
                    'providerType': provider_type,
                    'providerName': provider_name
                }
                self.post(path, data=payload, error=True)
            else:
                raise

        # Kibana will authenticate against URLs which contain invalid spaces
        if self.space:
            self.verify_space(self.space)

        self.authenticated = True
        self.status = self.get("/api/status")

        # create ES and force authentication
        if self.elasticsearch is None and self.elastic_url is not None:
            self.elasticsearch = Elasticsearch(hosts=[self.elastic_url], http_auth=(kibana_username, kibana_password),
                                               verify_certs=self.verify)
            self.elasticsearch.info()

        # make chaining easier
        return self

    def add_cookie(self, cookie):
        """Add cookie to be used for auth (such as from an SSO session)."""
        # https://www.elastic.co/guide/en/kibana/7.10/security-settings-kb.html#security-session-and-cookie-settings
        self.session.headers['sid'] = cookie
        self.session.cookies.set('sid', cookie)
        self.status = self.get('/api/status')
        self.authenticated = True

    def logout(self):
        """Quit the current session."""
        try:
            self.get('/logout', raw=True, error=False)
        except requests.exceptions.ConnectionError:
            # for really short scoping from buildup to teardown, ES will cause a Max retry error
            pass
        # reset all session state so the object can be reused for a fresh login
        self.status = None
        self.authenticated = False
        self.session = requests.Session()
        self.elasticsearch = None

    def __close(self):
        # atexit hook: only log out if a session was actually established
        if self.authenticated:
            self.logout()

    def __enter__(self):
        """Use the current Kibana instance for ``with`` syntax."""
        if not hasattr(_context, "stack"):
            _context.stack = []

        # Backup the previous Kibana instance and bind the current one
        _context.stack.append(self)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Use the current Kibana for ``with`` syntax."""
        _context.stack.pop()

    @classmethod
    def current(cls) -> 'Kibana':
        """Get the currently used Kibana stack."""
        stack = getattr(_context, "stack", [])
        if len(stack) == 0:
            raise RuntimeError("No Kibana connector in scope!")

        return stack[-1]

    def verify_space(self, space):
        """Verify a space is valid."""
        spaces = self.get('/api/spaces/space')
        space_names = [s['name'] for s in spaces]
        if space not in space_names:
            raise ValueError(f'Unknown Kibana space: {space}')

    def current_user(self):
        """Retrieve info for currently authenticated user."""
        if self.authenticated:
            return self.get('/internal/security/me')
+196
View File
@@ -0,0 +1,196 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import datetime
from typing import List, Optional, Type
from .connector import Kibana
DEFAULT_PAGE_SIZE = 10
class BaseResource(dict):
    """Base class for Kibana saved-object style resources.

    Subclasses set BASE_URI (and optionally ID_FIELD); instances are plain
    dicts holding the resource's attributes as returned by the API.
    """

    BASE_URI = ""
    ID_FIELD = "id"

    @property
    def id(self):
        """The resource's identifier, or None if not yet assigned."""
        return self.get(self.ID_FIELD)

    @classmethod
    def bulk_create(cls, resources: list):
        """Create multiple resources in one request; returns the created resources."""
        for r in resources:
            assert isinstance(r, cls)

        responses = Kibana.current().post(cls.BASE_URI + "/_bulk_create", data=resources)
        return [cls(r) for r in responses]

    def create(self):
        """Create this resource, merge the server response in, and return self."""
        response = Kibana.current().post(self.BASE_URI, data=self)
        self.update(response)
        return self

    @classmethod
    def find(cls, per_page=None, **params) -> 'ResourceIterator':
        """Iterate all matching resources, paging transparently."""
        if per_page is None:
            per_page = DEFAULT_PAGE_SIZE

        params.setdefault("sort_field", "_id")
        params.setdefault("sort_order", "asc")

        return ResourceIterator(cls, cls.BASE_URI + "/_find", per_page=per_page, **params)

    @classmethod
    def from_id(cls, resource_id) -> 'BaseResource':
        """Fetch a single resource by id.

        The response is wrapped in cls; previously the raw dict was returned,
        contradicting the annotation and losing the resource methods.
        """
        return cls(Kibana.current().get(cls.BASE_URI, params={cls.ID_FIELD: resource_id}))

    def put(self):
        """Update this resource in place, merge the server response, return self.

        The previous implementation called ``self.to_dict()`` and
        ``self._update_from(response)``, neither of which exists on this dict
        subclass, so any call raised AttributeError.
        """
        response = Kibana.current().put(self.BASE_URI, data=self)
        self.update(response)
        return self

    def delete(self):
        """Delete this resource by id."""
        return Kibana.current().delete(self.BASE_URI, params={"id": self.id})
class ResourceIterator(object):
    """Iterator that lazily pages through a Kibana `_find`-style endpoint."""

    def __init__(self, cls: Type[BaseResource], uri: str, per_page: int, **params: dict):
        self.cls = cls            # resource type used to wrap each returned hit
        self.uri = uri
        self.params = params      # extra query params forwarded on every page request
        self.page = 0             # last page fetched (server pages are 1-based)
        self.per_page = per_page
        self.fetched = 0          # running count of hits fetched so far
        self.current = None
        self.total = None         # None until the first page response arrives
        self.batch = []           # hits of the current page
        self.batch_pos = 0        # read position within self.batch
        self.kibana = Kibana.current()

    def __iter__(self):
        return self

    def _batch(self):
        # fetch the next page and reset the read position
        params = dict(per_page=self.per_page, page=self.page + 1, **self.params)
        response = self.kibana.get(self.uri, params=params, error=True)

        self.page = response["page"]
        self.per_page = response["perPage"]
        self.total = response["total"]
        self.batch = response["data"]
        self.batch_pos = 0
        self.fetched += len(self.batch)

    def __next__(self) -> BaseResource:
        # Fetch a page when nothing has been fetched yet (total is None), or
        # when the current batch was fully consumed AND was completely full
        # (a partial batch means the server had no more results).
        if self.total is None or 0 < self.batch_pos == len(self.batch) == self.per_page:
            self._batch()

        if self.batch_pos < len(self.batch):
            result = self.cls(self.batch[self.batch_pos])
            self.batch_pos += 1
            return result

        raise StopIteration()
class RuleResource(BaseResource):
    """Resource wrapper for detection-engine rules."""

    BASE_URI = "/api/detection_engine/rules"

    @staticmethod
    def _add_internal_filter(is_internal: bool, params: dict) -> dict:
        """AND an immutable-tag filter into params['filter'] and return params."""
        custom_filter = f'alert.attributes.tags:"__internal_immutable:{str(is_internal).lower()}"'

        if params.get("filter"):
            params["filter"] = f"({params['filter']}) and ({custom_filter})"
        else:
            params["filter"] = custom_filter

        return params

    @classmethod
    def find_custom(cls, **params):
        """Find user-created (non-immutable) rules."""
        params = cls._add_internal_filter(False, params)
        return cls.find(**params)

    @classmethod
    def find_elastic(cls, **params):
        # GET params:
        # * `sort_field`
        # * `sort_order`
        # * `filter` (accepts KQL)
        #       alert.attributes.name:mshta
        #       alert.attributes.enabled:true/false
        #
        # ...
        # i.e. Rule.find_elastic(filter="alert.attributes.name:mshta")
        params = cls._add_internal_filter(True, params)
        return cls.find(**params)

    def put(self):
        """Update the rule, temporarily stripping the disallowed rule_id field."""
        # id and rule_id are mutuallyexclusive
        rule_id = self.get("rule_id")
        self.pop("rule_id", None)

        try:
            # apparently Kibana doesn't like `rule_id` for existing documents.
            # Call the parent's put(); the previous code called super().update(),
            # which is dict.update() with no arguments — a silent no-op that
            # never performed the PUT.
            return super(RuleResource, self).put()
        except Exception:
            # if it fails, restore the id back
            if rule_id:
                self["rule_id"] = rule_id
            raise
class Signal(BaseResource):
    """Wrapper around detection-engine signal (alert) APIs.

    Signals are query-only: instances cannot be constructed; use the
    classmethods to search and to change signal status.
    """

    BASE_URI = "/api/detection_engine/signals"

    def __init__(self):
        raise NotImplementedError("Signals can't be instantiated yet")

    @classmethod
    def search(cls, query_dsl: dict, size: Optional[int] = 10):
        """Run a raw Elasticsearch query against the signals index."""
        payload = dict(size=size, **query_dsl)
        return Kibana.current().post(f"{cls.BASE_URI}/search", data=payload)

    @classmethod
    def last_signal(cls) -> (int, datetime.datetime):
        """Return (count of open signals, timestamp of the most recent one)."""
        query_dsl = {
            "aggs": {
                "lastSeen": {"max": {"field": "@timestamp"}}
            },
            'query': {
                "bool": {
                    "filter": [
                        {"match": {"signal.status": "open"}}
                    ]
                }
            },
            "size": 0,
            "track_total_hits": True
        }
        response = cls.search(query_dsl)
        # the aggregation is registered as "lastSeen" above; the previous code
        # read "last_seen", which never exists, so the timestamp was always None
        last_seen = response.get("aggregations", {}).get("lastSeen", {}).get("value_as_string")
        num_signals = response.get("hits", {}).get("total", {}).get("value")

        if last_seen is not None:
            last_seen = datetime.datetime.strptime(last_seen, "%Y-%m-%dT%H:%M:%S.%f%z")

        return num_signals, last_seen

    @classmethod
    def all(cls, size: Optional[int] = 10):
        """Fetch up to `size` signals with no filtering."""
        return cls.search({"query": {"bool": {"filter": {"match_all": {}}}}}, size=size)

    @classmethod
    def set_status_many(cls, signal_ids: List[str], status: str) -> dict:
        """Set the status ("open"/"closed") on many signals at once."""
        return Kibana.current().post(f"{cls.BASE_URI}/status", data={"signal_ids": signal_ids, "status": status})

    @classmethod
    def close_many(cls, signal_ids: List[str]):
        """Close many signals at once."""
        return cls.set_status_many(signal_ids, "closed")

    @classmethod
    def open_many(cls, signal_ids: List[str]):
        """Re-open many signals at once."""
        return cls.set_status_many(signal_ids, "open")
+28
View File
@@ -0,0 +1,28 @@
[project]
name = "detection-rules-kibana"
version = "0.1.0"
description = "Kibana API utilities for Elastic Detection Rules"
license = {text = "Elastic License v2"}
keywords = ["Elastic", "Kibana", "Detection Rules", "Security", "Elasticsearch"]
classifiers = [
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.12",
"Topic :: Security",
"Topic :: Software Development :: Build Tools",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development",
]
requires-python = ">=3.12"
dependencies = [
"requests>=2.25,<3.0",
"elasticsearch~=8.1",
]
[project.urls]
Homepage = "https://github.com/elastic/detection-rules"
License = "https://github.com/elastic/detection-rules/blob/main/LICENSE.txt"
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
+80
View File
@@ -0,0 +1,80 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import eql
from . import ast
from .dsl import ToDsl
from .eql2kql import Eql2Kql
from .errors import KqlParseError, KqlCompileError
from .evaluator import FilterGenerator
from .kql2eql import KqlToEQL
from .parser import lark_parse, KqlParser
# package version for the independently-packaged kql lib
__version__ = '0.1.6'

# public API of the kql package
__all__ = (
    "ast",
    "from_eql",
    "get_evaluator",
    "KqlParseError",
    "KqlCompileError",
    "lint",
    "parse",
    "to_dsl",
    "to_eql",
)
def to_dsl(parsed, optimize=True, schema=None):
    """Convert KQL (source text or AST) to Elasticsearch Query DSL."""
    node = parsed if isinstance(parsed, ast.KqlNode) else parse(parsed, optimize, schema)
    return ToDsl.convert(node)
def to_eql(text, optimize=True, schema=None):
    """Parse KQL *text* and convert it to an EQL expression."""
    source = text.decode("utf-8") if isinstance(text, bytes) else text
    tree = KqlToEQL(source, schema=schema).visit(lark_parse(source))
    if optimize:
        return tree.optimize(recursive=True)
    return tree
def parse(text, optimize=True, schema=None):
    """Parse KQL *text* into a KQL AST, optionally optimizing the tree."""
    source = text.decode("utf-8") if isinstance(text, bytes) else text
    tree = KqlParser(source, schema=schema).visit(lark_parse(source))
    if optimize:
        return tree.optimize(recursive=True)
    return tree
def lint(text):
    """Return the canonical (optimized, re-rendered) form of a KQL query."""
    source = text.decode("utf-8") if isinstance(text, bytes) else text
    return parse(source, optimize=True).render()
def from_eql(tree, optimize=True):
    """Convert an EQL query/expression (or its source text) to a KQL AST."""
    if not isinstance(tree, eql.ast.EqlNode):
        # try a full query first; fall back to a bare expression
        try:
            tree = eql.parse_query(tree, implied_any=True)
        except eql.EqlSemanticError:
            tree = eql.parse_expression(tree)

    kql_tree = Eql2Kql().walk(tree)
    if optimize:
        return kql_tree.optimize(recursive=True)
    return kql_tree
def get_evaluator(tree, optimize=False):
    """Build a callable document filter for a KQL query (text or AST)."""
    node = tree if isinstance(tree, ast.KqlNode) else parse(tree, optimize=optimize)
    return FilterGenerator().filter(node)
+247
View File
@@ -0,0 +1,247 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import re
from string import Template
from eql.ast import BaseNode
from eql.errors import EqlCompileError
from eql.utils import is_number, is_string
__all__ = (
"KqlNode",
"Value",
"Null",
"Number",
"Boolean",
"List",
"Expression",
"String",
"Wildcard",
"NotValue",
"OrValues",
"AndValues",
"AndExpr",
"OrExpr",
"NotExpr",
"FieldComparison",
"Field",
"FieldRange",
"NestedQuery",
"Exists",
)
class KqlNode(BaseNode):
    """Base class for all KQL AST nodes."""

    def optimize(self, recursive=True):
        # NOTE(review): `recursive` is accepted but not forwarded — the
        # Optimizer walk appears to always be recursive; confirm intended
        from .optimizer import Optimizer
        return Optimizer().walk(self)

    def _render(self):
        # default rendering delegates to the eql BaseNode template machinery
        return BaseNode.render(self)

    def render(self, precedence=None, **kwargs):
        """Render an EQL node and add parentheses to support orders of operation."""
        rendered = self._render(**kwargs)
        # parenthesize when this node binds more loosely than its parent context
        if precedence is not None and self.precedence is not None and self.precedence > precedence:
            return '({})'.format(rendered)
        return rendered
class Value(KqlNode):
    """Base class for literal values appearing in KQL expressions."""

    __slots__ = "value",
    precedence = 1

    def __init__(self, value):
        self.value = value

    @classmethod
    def from_python(cls, value):
        """Wrap a plain Python value in the matching KQL value node."""
        if value is None:
            node_cls = Null
        elif is_string(value) and ('*' in value or '?' in value):
            # strings with glob metacharacters become wildcards
            node_cls = Wildcard
        elif isinstance(value, bool):
            node_cls = Boolean
        elif is_number(value):
            node_cls = Number
        elif is_string(value):
            node_cls = String
        else:
            raise EqlCompileError("Unknown type {} for value {}".format(type(value).__name__, value))
        return node_cls(value)
class Null(Value):
    """The KQL `null` literal; always holds None regardless of input."""

    def __init__(self, value=None):
        super(Null, self).__init__(None)

    def _render(self):
        return "null"
class Number(Value):
    """Numeric literal, rendered with Python's default number formatting."""

    def _render(self):
        return "{}".format(self.value)
class Boolean(Value):
    """Boolean literal, rendered as the lowercase keywords true/false."""

    def _render(self):
        return 'true' if self.value else 'false'
class String(Value):
    """String literal, quoted and escaped only when necessary."""

    # strings fully matching this need no quoting or escaping at all
    unescapable = re.compile(r'^[^\\():<>"*{} \t\r\n]+$')
    # characters rewritten when the string must be double-quoted
    # NOTE(review): "\n" and "\\" are absent from this table, so they pass
    # through literally inside the quotes — confirm that is intended
    escapes = {"\t": "\\t", "\r": "\\r", "\"": "\\\""}

    def _render(self):
        # pass through as-is since nothing needs to be escaped
        if self.unescapable.match(self.value) is not None:
            return str(self.value)

        # build a character class of all escapable chars and substitute each
        regex = r"[{}]".format("".join(re.escape(s) for s in sorted(self.escapes)))
        return '"{}"'.format(re.sub(regex, lambda r: self.escapes[r.group()], self.value))
class Wildcard(Value):
    """String containing glob metacharacters, rendered unquoted with escapes."""

    escapes = {"\t": "\\t", "\r": "\\r"}
    slash_escaped = r'''^\\():<>"{} '''

    def _render(self):
        def escape(char):
            # backslash-escape KQL syntax characters; map whitespace escapes;
            # everything else (including * and ?) passes through unchanged
            if char in self.slash_escaped:
                return "\\" + char
            return self.escapes.get(char, char)

        return ''.join(escape(char) for char in self.value)
class List(KqlNode):
    """Base class for operator-joined lists of values or expressions."""

    __slots__ = "items",
    precedence = Value.precedence + 1
    operator = ""  # subclasses set "and" / "or"
    template = Template("$items")

    def __init__(self, items):
        self.items = items
        KqlNode.__init__(self)

    @property
    def delims(self):
        # join items with the boolean operator when rendering the template
        return {"items": " {} ".format(self.operator)}

    def __eq__(self, other):
        from .optimizer import Optimizer
        from functools import cmp_to_key

        # order-insensitive equality: sort both item lists with the
        # optimizer's canonical ordering before comparing
        if type(self) == type(other):
            a = list(self.items)
            b = list(other.items)
            a.sort(key=cmp_to_key(Optimizer.sort_key))
            b.sort(key=cmp_to_key(Optimizer.sort_key))
            return a == b
        return False
class NotValue(KqlNode):
    """Negation of a single value inside a field comparison."""

    __slots__ = "value",
    template = Template("not $value")
    precedence = Value.precedence + 1

    def __init__(self, value):
        self.value = value
        super(NotValue, self).__init__()
class AndValues(List):
    """Value list joined by `and` (binds tighter than `or`)."""
    precedence = List.precedence + 1
    operator = "and"
class OrValues(List):
    """Value list joined by `or` (binds looser than `and`)."""
    precedence = AndValues.precedence + 1
    operator = "or"
class Field(KqlNode):
    """Reference to a (possibly dotted) document field name."""

    __slots__ = "name",
    precedence = Value.precedence
    template = Template("$name")

    def __init__(self, name):
        self.name = name
        super(Field, self).__init__()

    @property
    def path(self):
        """The dotted name split into its individual components."""
        return self.name.split(".")

    @classmethod
    def from_path(cls, path):
        """Build a Field from an iterable of path components."""
        return cls(".".join(path))
class Expression(KqlNode):
    """Intermediate node for class hierarchy."""
class FieldRange(Expression, KqlNode):
    """Range comparison, e.g. `field >= 10`."""

    __slots__ = "field", "operator", "value",
    precedence = Field.precedence
    template = Template("$field $operator $value")

    def __init__(self, field, operator, value):
        self.field = field
        self.operator = operator  # one of < <= >= >
        self.value = value
class NestedQuery(Expression):
    """Query against a nested field, rendered as `field:{ expr }`."""

    __slots__ = "field", "expr",
    precedence = Field.precedence + 1
    template = Template("$field:{$expr}")

    def __init__(self, field, expr):
        self.field = field
        self.expr = expr
class FieldComparison(Expression):
    """Equality-style comparison, rendered as `field:value`."""

    __slots__ = "field", "value",
    precedence = FieldRange.precedence
    template = Template("$field:$value")

    def __init__(self, field, value):
        self.field = field
        self.value = value
class Exists(KqlNode):
    """Existence check value, rendered as `*` (used as `field:*`)."""

    __slots__ = tuple()
    precedence = FieldComparison.precedence
    template = Template("*")
class NotExpr(Expression):
    """Negation of an expression, rendered as `not expr`."""

    __slots__ = "expr",
    precedence = FieldComparison.precedence + 1
    template = Template("not $expr")

    def __init__(self, expr):
        self.expr = expr
class AndExpr(Expression, List):
    """Expression list joined by `and`."""
    precedence = NotExpr.precedence + 1
    operator = "and"
class OrExpr(Expression, List):
    """Expression list joined by `or` (loosest binding)."""
    precedence = AndExpr.precedence + 1
    operator = "or"
+118
View File
@@ -0,0 +1,118 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from collections import defaultdict
from eql import Walker
from .errors import KqlCompileError
def boolean(**kwargs):
    """Wrap a query in a boolean term and optimize while building.

    Called with exactly one keyword: must=, filter=, should= or must_not=,
    whose value is a clause dict or list of clause dicts. Nested bool
    queries of compatible types are flattened while assembling the result.
    """
    assert len(kwargs) == 1
    [(boolean_type, children)] = kwargs.items()

    if not isinstance(children, list):
        children = [children]

    dsl = defaultdict(list)

    if boolean_type in ("must", "filter"):
        # safe to convert and(and(x), y) -> and(x, y)
        for child in children:
            # list(child) == ["bool"] means the child is a pure bool wrapper
            if list(child) == ["bool"]:
                for child_type, child_terms in child["bool"].items():
                    if child_type in ("must", "filter"):
                        dsl[child_type].extend(child_terms)
                    elif child_type == "should":
                        # only hoist the first `should`; further ones must stay
                        # wrapped so their minimum_should_match semantics hold
                        if "should" not in dsl:
                            dsl[child_type].extend(child_terms)
                        else:
                            dsl[boolean_type].append(boolean(should=child_terms))
                    elif child_type == "must_not":
                        dsl[child_type].extend(child_terms)
                    elif child_type != "minimum_should_match":
                        raise ValueError("Unknown term {}: {}".format(child_type, child_terms))
            else:
                dsl[boolean_type].append(child)
    elif boolean_type == "should":
        # can flatten `should` of `should`
        for child in children:
            if list(child) == ["bool"] and set(child["bool"]).issubset({"should", "minimum_should_match"}):
                dsl["should"].extend(child["bool"]["should"])
            else:
                dsl[boolean_type].append(child)
    elif boolean_type == "must_not" and len(children) == 1:
        # must_not: [{bool: {must: x}}] -> {must_not: x}
        child = children[0]
        if list(child) == ["bool"] and list(child["bool"]) in (["filter"], ["must"]):
            negated, = child["bool"].values()
            dsl = {"must_not": negated}
        else:
            dsl = {"must_not": children}
    else:
        dsl = dict(kwargs)

    # a should clause needs at least one branch to match
    if "should" in dsl:
        dsl.update(minimum_should_match=1)

    dsl = {"bool": dict(dsl)}
    return dsl
class ToDsl(Walker):
    """Translate a KQL AST into an Elasticsearch bool-query DSL dict.

    Value-level walkers return callables taking a field name, so a parent
    field comparison can apply them once the field is known.
    """

    def _walk_default(self, node, *args, **kwargs):
        raise KqlCompileError("Unable to convert {}".format(node))

    def _walk_exists(self, _):
        def make(field):
            return {"exists": {"field": field}}
        return make

    def _walk_wildcard(self, tree):
        pattern = tree.value

        def make(field):
            return {"query_string": {"fields": [field], "query": pattern}}
        return make

    def _walk_value(self, tree):
        literal = tree.value

        def make(field):
            return {"match": {field: literal}}
        return make

    def _walk_field(self, field):
        return field.name

    def _walk_field_range(self, tree):
        range_ops = {"<": "lt", "<=": "lte", ">=": "gte", ">": "gt"}
        name = self.walk(tree.field)
        return {"range": {name: {range_ops[tree.operator]: tree.value.value}}}

    def _walk_not_expr(self, tree):
        return boolean(must_not=[self.walk(tree.expr)])

    def _walk_and_expr(self, tree):
        return boolean(filter=[self.walk(child) for child in tree.items])

    def _walk_or_expr(self, tree):
        return boolean(should=[self.walk(child) for child in tree.items])

    def _walk_and_values(self, tree):
        walked = [self.walk(child) for child in tree.items]

        def make(field):
            return boolean(filter=[fn(field) for fn in walked])
        return make

    def _walk_or_values(self, tree):
        walked = [self.walk(child) for child in tree.items]

        def make(field):
            return boolean(should=[fn(field) for fn in walked])
        return make

    def _walk_not_value(self, tree):
        inner = self.walk(tree.value)

        def make(field):
            return boolean(must_not=[inner(field)])
        return make

    def _walk_field_comparison(self, tree):
        name = self.walk(tree.field)
        value_fn = self.walk(tree.value)
        return value_fn(name)

    @classmethod
    def convert(cls, tree):
        """Entry point: walk *tree* and wrap it in a top-level filter clause."""
        return boolean(filter=[cls().walk(tree)])
+129
View File
@@ -0,0 +1,129 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import eql
from eql import DepthFirstWalker
from .ast import (
Value, String, OrValues, Field, Expression, FieldRange, FieldComparison,
NotExpr, AndExpr, OrExpr, Exists, Wildcard
)
class Eql2Kql(DepthFirstWalker):
    """Convert an EQL AST into an equivalent KQL AST, bottom-up."""

    def _walk_default(self, tree, *args, **kwargs):
        # any EQL node without an explicit handler below is unconvertible
        if isinstance(tree, eql.ast.EqlNode):
            raise eql.errors.EqlCompileError("Unable to convert {}".format(tree))
        else:
            return tree

    def check_field_expression(self, tree):
        """Raise unless *tree* was already converted into a KQL Expression."""
        if not isinstance(tree, Expression):
            raise eql.errors.EqlCompileError("Expected expression, but got {}".format(repr(tree)))
        return tree

    def check_field_expressions(self, trees):
        """Validate every element of *trees* is a KQL Expression."""
        for tree in trees:
            self.check_field_expression(tree)
        return trees

    def _walk_and(self, tree):  # type: (eql.ast.And) -> AndExpr
        return AndExpr(self.check_field_expressions(tree.terms))

    def _walk_or(self, tree):  # type: (eql.ast.Or) -> OrExpr
        return OrExpr(self.check_field_expressions(tree.terms))

    def _walk_not(self, tree):  # type: (eql.ast.Not) -> NotExpr
        return NotExpr(self.check_field_expression(tree.term))

    def _walk_is_null(self, node):  # type: (eql.ast.IsNull) -> FieldComparison
        if not isinstance(node.expr, Field):
            raise eql.errors.EqlCompileError("Unable to compare a non-field [{}] to null".format(node.expr))
        # `field == null`  ->  `not field:*`
        return NotExpr(FieldComparison(node.expr, Exists()))

    def _walk_is_not_null(self, node):  # type: (eql.ast.IsNotNull) -> Expression
        if not isinstance(node.expr, Field):
            raise eql.errors.EqlCompileError("Unable to compare a non-field [{}] to null".format(node.expr))
        return FieldComparison(node.expr, Exists())

    def _walk_field(self, tree):  # type: (eql.ast.Field) -> Field
        # numeric path components mean array indexing, which KQL cannot express
        if any(eql.utils.is_number(n) for n in tree.path):
            raise eql.errors.EqlCompileError("Unable to convert array field: {}".format(tree))
        return Field(tree.render())

    def _walk_in_set(self, tree):  # type: (eql.ast.InSet) -> FieldComparison
        if not isinstance(tree.expression, Field) or not all(isinstance(v, Value) for v in tree.container):
            raise eql.errors.EqlCompileError("Unable to convert `{}`".format(tree.expression, tree))
        # `field in (a, b)`  ->  `field:(a or b)`
        return FieldComparison(tree.expression, OrValues(tree.container))

    def _walk_function_call(self, tree):  # type: (eql.ast.FunctionCall) -> KqlNode
        # only wildcard() and cidrMatch() with a field first-argument convert
        if tree.name in ("wildcard", "cidrMatch"):
            if isinstance(tree.arguments[0], Field):
                if tree.name == "wildcard":
                    args = []
                    for arg in tree.arguments[1:]:
                        # only promote to Wildcard when glob metacharacters exist
                        if '*' in arg.value or '?' in arg.value:
                            args.append(Wildcard(arg.value))
                        else:
                            args.append(arg)
                    return FieldComparison(tree.arguments[0], OrValues(args))
                else:
                    return FieldComparison(tree.arguments[0], OrValues(tree.arguments[1:]))

        raise eql.errors.EqlCompileError("Unable to convert `{}`".format(tree))

    def _walk_literal(self, tree):
        return Value.from_python(tree.value)

    def _walk_event_query(self, tree):  # type: (eql.ast.EventQuery) -> KqlNode
        if tree.event_type == eql.schema.EVENT_TYPE_ANY:
            return self.check_field_expression(tree.query)

        # constrain by event.category for typed queries
        event_check = FieldComparison(Field("event.category"), String(tree.event_type))

        # for `x where true` shorthand, drop the `where true`
        if tree.query == Value.from_python(True):
            return event_check

        self.check_field_expression(tree.query)
        return AndExpr([event_check, tree.query])

    def _walk_filter_pipe(self, tree):  # type: (eql.pipes.FilterPipe) -> KqlNode
        return self.check_field_expression(tree.expression)

    def _walk_piped_query(self, tree):  # type: (eql.ast.PipedQuery) -> KqlNode
        if not tree.pipes:
            return tree.first
        return AndExpr([tree.first] + tree.pipes)

    # comparison operators and their mirror images (for literal-on-the-left)
    LT, LE, EQ, NE, GE, GT = ('<', '<=', '==', '!=', '>=', '>')
    flipped = {LT: GE, LE: GT,
               EQ: EQ, NE: NE,
               GE: LT, GT: LE}

    def _walk_comparison(self, tree):  # type: (eql.ast.Comparison) -> KqlNode
        left = tree.left
        op = tree.comparator
        right = tree.right

        # move the literal to the right
        if isinstance(left, eql.ast.Literal):
            left, right = right, left
            op = self.flipped[op]

        if isinstance(left, Field) and isinstance(right, Value):
            if op == eql.ast.Comparison.EQ:
                return FieldComparison(left, right)
            elif op == eql.ast.Comparison.NE:
                return NotExpr(FieldComparison(left, right))
            else:
                return FieldRange(left, op, right)

        raise eql.errors.EqlCompileError("Unable to convert {}".format(tree))
+18
View File
@@ -0,0 +1,18 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from eql import EqlError, EqlParseError, EqlCompileError
class KqlParseError(EqlParseError):
    """KQL parsing error."""
class KqlCompileError(EqlCompileError):
    """Class for KQL-specific compile errors."""
class KqlRuntimeError(EqlError):
    """Error for failures within the KQL evaluator."""
+156
View File
@@ -0,0 +1,156 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import operator
import re
import eql.ast
from eql import Walker, EqlCompileError, utils
from eql.functions import CidrMatch
from .errors import KqlRuntimeError, KqlCompileError
from .parser import is_ipaddress
class FilterGenerator(Walker):
__cidr_cache = {}
def _walk_default(self, node, *args, **kwargs):
raise KqlCompileError("Unable to convert {}".format(node))
@classmethod
def equals(cls, term, value):
"""Check if a term is equal to a value."""
if utils.is_string(term) and utils.is_string(value):
if is_ipaddress(term) and eql.utils.is_cidr_pattern(value):
# check for an ipv4 cidr
if value not in cls.__cidr_cache:
cls.__cidr_cache[value] = CidrMatch.get_callback(None, eql.ast.String(value))
return cls.__cidr_cache[value](term)
return term == value
@classmethod
def get_terms(cls, document, path):
if isinstance(document, (tuple, list)):
for d in document:
for term in cls.get_terms(d, path):
yield term
elif isinstance(document, dict):
document = document.get(path[0])
path = path[1:]
if len(path) > 0:
for term in cls.get_terms(document, path):
yield term
elif isinstance(document, (tuple, list)):
for term in document:
yield term
elif document is not None:
yield document
def _walk_value(self, tree, compare_function=None):
value = tree.value
compare_function = compare_function or self.equals
def check_value(term):
if term is None:
return False
if isinstance(term, list):
return any(check_value(t) for t in term)
if isinstance(term, (bool, float, int)) or utils.is_string(term):
v = value
if utils.is_string(v) and isinstance(term, (bool, int, float)):
if isinstance(v, bool):
v = v == "false"
if isinstance(term, int):
v = int(v)
elif isinstance(v, float):
v = float(v)
elif utils.is_string(term) and isinstance(v, (bool, int, float)):
v = utils.to_unicode(v)
return compare_function(term, v)
else:
raise KqlRuntimeError("Cannot compare value {}".format(term))
return check_value
def _walk_exists(self, _):
return lambda terms: any(t is not None for t in terms)
def _walk_wildcard(self, tree):
pattern = tree.value
regex = re.compile(".*?".join(map(re.escape, pattern.split("*"))), re.UNICODE | re.DOTALL)
return lambda terms: any(t is not None and regex.fullmatch(t) for t in terms)
def _walk_field(self, field):
path = field.name.split(".")
get_terms = self.get_terms
def callback(document):
terms = get_terms(document, path)
terms = list(terms)
return terms
return callback
def _walk_field_range(self, tree):
    """Build a callback that evaluates a range comparison (<, <=, >=, >)."""
    comparators = {"<": operator.lt, "<=": operator.le, ">=": operator.ge, ">": operator.gt}
    get_field = self.walk(tree.field)
    check_range = self.walk(tree.value, comparators[tree.operator])

    def compare(doc):
        return check_range(get_field(doc))

    return compare
def _walk_nested_query(self, tree):
    """Build a callback that applies a sub-expression inside a nested field."""
    get_field = self.walk(tree.field)
    match_expr = self.walk(tree.expr)

    def check_nested(document):
        nested = get_field(document)
        if isinstance(nested, dict):
            return match_expr(nested)
        if isinstance(nested, (list, tuple)):
            return any(match_expr(d) for d in nested)
        # any other shape (including None) falls through, returning None

    return check_nested
def _walk_list(self, trees, reduce_function, *args, **kwargs):
    """Walk every item and combine the callbacks with *reduce_function* (any/all)."""
    callbacks = [self.walk(node, *args, **kwargs) for node in trees.items]

    def reduced(x):
        return reduce_function(cb(x) for cb in callbacks)

    return reduced
def _walk_not_expr(self, tree):
    """Negate the walked sub-expression."""
    inner = self.walk(tree.expr)

    def negated(doc):
        return not inner(doc)

    return negated
def _walk_and_expr(self, tree):
    """Conjunction: every sub-expression must match the document."""
    return self._walk_list(tree, all)
def _walk_or_expr(self, tree):
    """Disjunction: any matching sub-expression matches the document."""
    return self._walk_list(tree, any)
def _walk_and_values(self, tree):
    """Conjunction over a list of value checks."""
    return self._walk_list(tree, all)
def _walk_or_values(self, tree):
    """Disjunction over a list of value checks."""
    return self._walk_list(tree, any)
def _walk_not_value(self, tree):
    """Negate a walked value check."""
    check = self.walk(tree.value)

    def negated(value):
        return not check(value)

    return negated
def _walk_field_comparison(self, tree):
    """Build a callback extracting a field's terms and applying the value check."""
    get_terms = self.walk(tree.field)
    check = self.walk(tree.value)

    def compare(document):
        return check(get_terms(document))

    return compare
@classmethod
def filter(cls, expression):
    """Compile *expression* into a predicate callable over documents.

    NOTE: shadows the ``filter`` builtin within the class namespace.
    """
    return cls().walk(expression)
+51
View File
@@ -0,0 +1,51 @@
// KQL grammar for lark (loaded with parser='lalr', propagate_positions=True).
// A leading '?' inlines a rule when it matches a single child.

// boolean operators, loosest binding first: OR, then AND, then NOT
?query: or_query
?or_query: and_query (OR and_query)*
?and_query: not_query (AND not_query)*
?not_query: NOT? sub_query

// parenthesized groups and nested-field queries
?sub_query: "(" or_query ")"
          | nested_query
?nested_query: field ":" "{" or_query "}"
             | expression

// an expression is a range check, a field:value check, or a bare value
?expression: field_range_expression
           | field_value_expression
           | value_expression
field_range_expression: field RANGE_OPERATOR literal
field_value_expression: field ":" list_of_values
?value_expression: value

// value lists carry their own and/or/not combinators
?list_of_values: "(" or_list_of_values ")"
               | value
?or_list_of_values: and_list_of_values (OR and_list_of_values)*
?and_list_of_values: not_list_of_values (AND not_list_of_values)*
?not_list_of_values: NOT? list_of_values

field: literal
value: QUOTED_STRING
     | UNQUOTED_LITERAL
literal: QUOTED_STRING
       | UNQUOTED_LITERAL

// terminals
RANGE_OPERATOR: "<="
              | ">="
              | "<"
              | ">"
UNQUOTED_LITERAL: UNQUOTED_CHAR+
UNQUOTED_CHAR: "\\" /[trn]/ // escaped whitespace
             | "\\" /[\\():<>"*{}]/ // escaped specials
             | "\\" (AND | OR | NOT) // escaped keywords
             | "*" // wildcard
             | /[^\\():<>"*{} \t\r\n]/ // anything else
QUOTED_STRING: /"(\\[tnr"\\]|[^\r\n"])*"/

// keyword terminals get priority 2 so they take precedence over plain literals
OR.2: "or" | "OR"
AND.2: "and" | "AND"
NOT.2: "not" | "NOT"

WHITESPACE: (" " | "\r" | "\n" | "\t" )+
%ignore WHITESPACE
+106
View File
@@ -0,0 +1,106 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import eql
from .parser import BaseKqlParser
# Elasticsearch field types that EQL cannot query against
NOT_SUPPORTED_EQL_FIELDS = ["text"]
# https://github.com/elastic/eql/issues/17
class KqlToEQL(BaseKqlParser):
    """Lark tree interpreter that converts a parsed KQL query into an EQL AST."""

    #
    # Lark Visit methods
    #
    @staticmethod
    def to_eql_field(name):
        # split "a.b.c" into a base field plus its sub-path components
        path = name.split(".")
        return eql.ast.Field(path[0], path[1:])

    def or_query(self, tree):
        terms = [self.visit(t) for t in tree.child_trees]
        return eql.ast.Or(terms)

    def and_query(self, tree):
        terms = [self.visit(t) for t in tree.child_trees]
        return eql.ast.And(terms)

    def not_query(self, tree):
        # the last child is the negated sub-query (preceding children are NOT tokens)
        return eql.ast.Not(self.visit(tree.children[-1]))

    def nested_query(self, tree):
        # EQL has no equivalent for KQL's nested-field `{...}` syntax
        raise self.error(tree, "Unable to convert nested query to EQL")

    def field_range_expression(self, tree):
        field_tree, operator, literal_tree = tree.children
        field_name = self.visit(field_tree)

        # check the field against the schema
        self.get_field_type(field_name, field_tree)

        # get and convert the value
        value = self.convert_value(field_name, self.visit(literal_tree), literal_tree)
        literal = eql.ast.Literal.from_python(value)
        field = self.to_eql_field(field_name)
        return eql.ast.Comparison(field, operator.value, literal)

    def field_value_expression(self, tree):
        field_tree, value_tree = tree.child_trees

        # visit the value with the field in scope so value() can type-check it
        with self.scope(self.visit(field_tree)) as field_name:
            # check the field against the schema
            type_mapping = self.get_field_type(field_name, field_tree)
            if type_mapping in NOT_SUPPORTED_EQL_FIELDS:
                err_msg = f"{field_name} uses an unsupported elasticsearch eql field_type {type_mapping}"
                raise eql.EqlSemanticError(err_msg, field_tree.line, field_tree.column, self.text)
            return self.visit(value_tree)

    def or_list_of_values(self, tree):
        children = [self.visit(t) for t in tree.child_trees]
        return eql.ast.Or(children)

    def and_list_of_values(self, tree):
        children = [self.visit(t) for t in tree.child_trees]
        return eql.ast.And(children)

    def not_list_of_values(self, tree):
        return eql.ast.Not(self.visit(tree.children[-1]))

    def field(self, tree):
        # a field is just its literal name, normalized to unicode
        literal = self.visit(tree.children[0])
        return eql.utils.to_unicode(literal)

    def value(self, tree):
        """Convert one literal to the matching EQL node (null/exists/wildcard/cidr/==)."""
        # TODO: check the logic for kuery.peg
        value = self.unescape_literal(tree.children[0])

        if self.scoped_field is None:
            raise self.error(tree, "Value not tied to field")

        field_name = self.scoped_field
        field = self.to_eql_field(field_name)
        value = self.convert_value(field_name, value, tree)
        value_ast = eql.ast.Literal.from_python(value)

        # `field: null` -> null check
        if value is None:
            return eql.ast.IsNull(field)

        # a pure-wildcard value (e.g. `field: *`) only asserts existence
        if eql.utils.is_string(value) and value.replace("*", "") == "":
            return eql.ast.IsNotNull(field)

        # wildcards become wildcard() function calls
        if eql.utils.is_string(value) and "*" in value:
            return eql.ast.FunctionCall("wildcard", [field, value_ast])

        # CIDR literals on ip fields become cidrMatch() calls
        if self.get_field_types(field_name) == {"ip"} and "/" in value:
            return eql.ast.FunctionCall("cidrMatch", [field, value_ast])

        return eql.ast.Comparison(field, "==", value_ast)

    def literal(self, tree):
        return self.unescape_literal(tree.children[0])
+130
View File
@@ -0,0 +1,130 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import functools
from eql import Walker, DepthFirstWalker
from .ast import AndValues, NotValue, Value, OrValues, NotExpr, FieldComparison
class Optimizer(DepthFirstWalker):
    """Depth-first rewriter that flattens, groups, and canonically sorts KQL ASTs.

    Note: some rules mutate nodes in place (e.g. ``term.value = ...``), so the
    input tree should be considered consumed by the optimization.
    """

    def flat_optimize(self, tree):
        # re-dispatch on just this node (non-recursive), via the base Walker
        return Walker.walk(self, tree)

    def _walk_default(self, tree, *args, **kwargs):
        # nodes without a specific rule pass through untouched
        return tree

    def group_fields(self, tree, value_cls):  # type: (List, type) -> KqlNode
        """Merge multiple comparisons on one field into a single comparison
        whose value combines all checks under *value_cls* (AndValues/OrValues)."""
        cls = type(tree)
        field_groups = {}
        ungrouped = []

        for term in tree.items:
            # move a `not` inwards before grouping
            if isinstance(term, NotExpr) and isinstance(term.expr, FieldComparison):
                term = FieldComparison(term.expr.field, NotValue(term.expr.value))

            if isinstance(term, FieldComparison):
                if term.field.name in field_groups:
                    # fold later comparisons on a seen field into the first one,
                    # so only the first occurrence stays in `ungrouped`
                    existing_checks = field_groups[term.field.name]
                    existing_checks.append(term)
                    continue
                else:
                    field_groups[term.field.name] = [term]

            ungrouped.append(term)

        # rebuild each surviving comparison's value from its whole group
        for term in ungrouped:
            if isinstance(term, FieldComparison):
                term.value = self.flat_optimize(value_cls([t.value for t in field_groups[term.field.name]]))

        ungrouped = [self.flat_optimize(u) for u in ungrouped]
        return cls(ungrouped) if len(ungrouped) > 1 else ungrouped[0]

    @staticmethod
    def sort_key(a, b):
        """Three-way comparator: Values first, then by type name, then by value."""
        if isinstance(a, Value) and not isinstance(b, Value):
            return -1
        if not isinstance(a, Value) and isinstance(b, Value):
            return +1
        if isinstance(a, Value) and isinstance(b, Value):
            t_a = type(a)
            t_b = type(b)

            if t_a == t_b:
                return (a.value > b.value) - (a.value < b.value)
            else:
                return (t_a.__name__ > t_b.__name__) - (t_a.__name__ < t_b.__name__)
        else:
            # unable to compare
            return 0

    def _walk_field_comparison(self, tree):  # type: (FieldComparison) -> KqlNode
        # if there's a single `not`, then pull it out of the expression
        if isinstance(tree.value, NotValue):
            return NotExpr(FieldComparison(tree.field, tree.value.value))
        return tree

    def flatten(self, tree):  # type: (List) -> List
        """Splice same-type children into the parent: a AND (b AND c) -> a AND b AND c."""
        cls = type(tree)
        flattened = []

        for node in tree.items:
            if isinstance(node, cls):
                flattened.extend(node.items)
            else:
                flattened.append(node)

        flattened = [self.flat_optimize(t) for t in flattened]
        return cls(flattened)

    def flatten_values(self, tree, dual_cls):  # type: (List, type) -> List
        """Flatten a value list, merging all negated values into one `not`
        over *dual_cls* (De Morgan: not a OR not b == not (a AND b))."""
        cls = type(tree)
        flattened = []
        not_term = None

        for term in self.flatten(tree).items:
            if isinstance(term, NotValue) and isinstance(term.value, Value):
                # create a copy to leave the source tree unaltered
                term = NotValue(term.value)
                if not_term is None:
                    # first negated value: keep it in `flattened` and remember it
                    not_term = term
                else:
                    # later negated values merge into the remembered node in place
                    not_term.value = dual_cls([not_term.value, term.value])
                    continue

            flattened.append(term)

        if not_term is not None:
            not_term.value = self.flat_optimize(not_term.value)

        flattened = [self.flat_optimize(t) for t in flattened]
        flattened.sort(key=functools.cmp_to_key(self.sort_key))
        return cls(flattened) if len(flattened) > 1 else flattened[0]

    def _walk_not_value(self, tree):  # type: (NotValue) -> KqlNode
        # double negation: not not x -> x
        if isinstance(tree.value, NotValue):
            return tree.value.value
        return tree

    def _walk_or_values(self, tree):
        # AndValues is the De Morgan dual used when merging `not` terms
        return self.flatten_values(tree, AndValues)

    def _walk_and_values(self, tree):
        # OrValues is the De Morgan dual used when merging `not` terms
        return self.flatten_values(tree, OrValues)

    def _walk_not_expr(self, tree):  # type: (NotExpr) -> KqlNode
        # double negation: not not x -> x
        if isinstance(tree.expr, NotExpr):
            return tree.expr.expr
        return tree

    def _walk_and_expr(self, tree):  # type: (AndExpr) -> KqlNode
        return self.group_fields(self.flatten(tree), value_cls=AndValues)

    def _walk_or_expr(self, tree):  # type: (OrExpr) -> KqlNode
        return self.group_fields(self.flatten(tree), value_cls=OrValues)
+377
View File
@@ -0,0 +1,377 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import contextlib
import os
import re
from typing import Optional, Set
import eql
from lark import Token # noqa: F401
from lark import Tree, Lark
from lark.exceptions import LarkError, UnexpectedEOF
from lark.visitors import Interpreter
from kql.errors import KqlParseError
from .ast import * # noqa: F403
# mapping-type families treated as string data (see elasticsearch_type_family)
STRING_FIELDS = ("keyword", "text")
class KvTree(Tree):
    """Lark ``Tree`` subclass exposing typed views over its children."""

    @property
    def child_trees(self):
        """The children that are subtrees (rule matches)."""
        return [node for node in self.children if isinstance(node, KvTree)]

    @property
    def child_tokens(self):
        """The children that are lexer tokens (terminals)."""
        return [node for node in self.children if isinstance(node, Token)]
# load the KQL grammar shipped alongside this module
grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kql.g")
with open(grammar_file, "rt") as f:
    grammar = f.read()

# single shared LALR parser; positions are propagated so errors can point at source
lark_parser = Lark(grammar, propagate_positions=True, tree_class=KvTree, start=['query'], parser='lalr')
def is_ipaddress(value: str) -> bool:
    """Check if a value is an ip address."""
    try:
        eql.utils.get_ipaddress(value)
    except ValueError:
        return False
    return True
def wildcard2regex(wc: str) -> re.Pattern:
    """Compile a ``*``-wildcard pattern into an anchored, lazy regex."""
    escaped = ".*?".join(map(re.escape, wc.split("*")))
    return re.compile(f"^{escaped}$")
def elasticsearch_type_family(mapping_type: str) -> str:
    """Get the family of type for an Elasticsearch mapping type.

    Concrete mapping types (e.g. ``long_range``, ``scaled_float``) collapse to
    the broader family (``range``, ``float``) used by the KQL type checks;
    unknown types are returned unchanged.
    """
    # https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html
    return {
        # range types
        "long_range": "range",
        "double_range": "range",
        "date_range": "range",
        "ip_range": "range",
        # text search types
        # BUGFIX: the Elasticsearch type names are "annotated_text" and
        # "search_as_you_type"; the hyphenated spellings never matched a real
        # mapping type. The old keys are kept for backward compatibility.
        "annotated_text": "text",
        "annotated-text": "text",
        "completion": "text",
        "match_only_text": "text",
        "search_as_you_type": "text",
        "search-as_you_type": "text",
        # keyword
        "constant_keyword": "keyword",
        "wildcard": "keyword",
        # date
        "date_nanos": "date",
        # integer
        "token_count": "integer",
        "long": "integer",
        "short": "integer",
        "byte": "integer",
        "unsigned_long": "integer",
        # float
        "double": "float",
        "half_float": "float",
        "scaled_float": "float",
    }.get(mapping_type, mapping_type)
class BaseKqlParser(Interpreter):
    """Shared machinery for interpreting a lark-parsed KQL tree.

    Holds the original query text (for error reporting), an optional
    {field: type} schema, and escape tables for quoted/unquoted literals.
    """

    # whitespace other than a plain space (used to detect multi-line spans)
    NON_SPACE_WS = re.compile(r"[^\S ]+")

    # escape sequences accepted inside unquoted literals
    unquoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n"}
    # NOTE(review): the trailing "]" looks unintended -- the grammar only allows
    # escaping \():<>"*{} -- confirm before changing.
    for special in "\\():<>\"*{}]":
        unquoted_escapes["\\" + special] = special

    unquoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(unquoted_escapes)) + ")")
    # escape sequences accepted inside quoted strings
    quoted_escapes = {"\\t": "\t", "\\r": "\r", "\\n": "\n", "\\\\": "\\", "\\\"": "\""}
    quoted_regex = re.compile("(" + "|".join(re.escape(e) for e in sorted(quoted_escapes)) + ")")

    def __init__(self, text, schema=None):
        """Store the query text and optional schema; precompile wildcard fields."""
        self.text = text
        # strip line endings so error displays don't double-break
        self.lines = [t.rstrip("\r\n") for t in self.text.splitlines(True)]
        self.scoped_field = None
        self.mapping_schema = schema
        self.star_fields = []

        if schema:
            for field, field_type in schema.items():
                # schema field names containing '*' become regexes to match against
                if "*" in field:
                    self.star_fields.append(wildcard2regex(field))

    def assert_lower_token(self, *tokens):
        """Raise when a keyword token (and/or/not) is not lowercase."""
        for token in tokens:
            if str(token) != str(token).lower():
                raise self.error(token, "Expected '{lower}' but got '{token}'".format(token=token, lower=str(token).lower()))

    def error(self, node, message, end=False, cls=KqlParseError, width=None, **kwargs):
        """Generate an error exception but don't raise it.

        The returned exception carries surrounding source lines and a caret
        width so callers can render a pointer under the offending span.
        """
        if kwargs:
            message = message.format(**kwargs)

        line_number = node.line - 1
        column = node.column - 1

        # get more lines for more informative error messages. three before + two after
        before = self.lines[:line_number + 1][-3:]
        after = self.lines[line_number + 1:][:3]

        source = '\n'.join(b for b in before)
        trailer = '\n'.join(a for a in after)

        # Determine if the error message can easily look like this
        #    ^^^^
        if width is None and not end and node.line == node.end_line:
            if not self.NON_SPACE_WS.search(self.lines[line_number][column:node.end_column]):
                width = node.end_column - node.column

        if width is None:
            width = 1

        return cls(message, line_number, column, source, width=width, trailer=trailer)

    def __default__(self, tree):
        # fail loudly for any grammar rule without an explicit visitor method
        raise NotImplementedError("Unable to visit tree {} of type: {}".format(tree, tree.data))

    def unescape_literal(self, token):  # type: (Token) -> (int|float|str|bool)
        """Convert a QUOTED_STRING/UNQUOTED_LITERAL token to its Python value."""
        if token.type == "QUOTED_STRING":
            return self.convert_quoted_string(token.value)
        else:
            return self.convert_unquoted_literal(token.value)

    @contextlib.contextmanager
    def scope(self, field):
        """Temporarily record the field a value is being compared against.

        Usage: ``with self.scope(field) as field: ...``

        NOTE(review): the reset is not in a try/finally, so ``scoped_field``
        stays set if the body raises -- confirm whether that matters.
        """
        self.scoped_field = field
        yield field
        self.scoped_field = None

    def get_field_type(self, dotted_path, lark_tree=None):
        """Look up a field's mapping type; with a tree, unknown fields raise."""
        matches_pattern = any(regex.match(dotted_path) for regex in self.star_fields)

        if self.mapping_schema is not None:
            if lark_tree is not None and dotted_path not in self.mapping_schema and not matches_pattern:
                raise self.error(lark_tree, "Unknown field")

            return self.mapping_schema.get(dotted_path)

    def get_field_types(self, wildcard_dotted_path, lark_tree=None) -> Optional[Set[str]]:
        """Resolve the set of mapping types a (possibly wildcarded) path can have.

        Returns None when no schema is configured (nothing to check against).
        """
        if "*" not in wildcard_dotted_path:
            field_type = self.get_field_type(wildcard_dotted_path, lark_tree=lark_tree)
            return {field_type} if field_type is not None else None

        if self.mapping_schema is not None:
            regex = wildcard2regex(wildcard_dotted_path)
            field_types = set()

            for field, field_type in self.mapping_schema.items():
                if regex.fullmatch(field) is not None:
                    field_types.add(field_type)

            if len(field_types) == 0:
                raise self.error(lark_tree, "Unknown field")

            return field_types

    @staticmethod
    def get_literal_type(literal_value):
        """Map a Python literal to its Elasticsearch type name."""
        # bool is checked before int/float because bool is an int subclass
        if isinstance(literal_value, bool):
            return "boolean"
        elif isinstance(literal_value, float):
            return "float"
        elif isinstance(literal_value, int):
            return "long"
        elif eql.utils.is_string(literal_value):
            # this will be converted when compared to the field
            return "keyword"
        elif literal_value is None:
            return "null"
        else:
            raise NotImplementedError("Unknown literal type: {}".format(type(literal_value).__name__))

    def convert_value(self, field_name, python_value, value_tree):
        """Coerce *python_value* toward the schema type of *field_name*, or raise."""
        field_type = None
        field_types = self.get_field_types(field_name)
        value_type = self.get_literal_type(python_value)

        if field_types is not None:
            if len(field_types) == 1:
                field_type = list(field_types)[0]
            elif len(field_types) > 1:
                # a wildcard path matching differently-typed fields is ambiguous
                raise self.error(value_tree,
                                 f"{field_name} has multiple types {', '.join(field_types)}")

        if field_type is not None and field_type != value_type:
            field_type_family = elasticsearch_type_family(field_type)

            if field_type_family in STRING_FIELDS:
                return eql.utils.to_unicode(python_value)
            elif field_type_family in ("float", "integer"):
                try:
                    return float(python_value) if field_type_family == "float" else int(python_value)
                except ValueError:
                    # fall through to the type-mismatch error below
                    pass
            elif field_type_family == "ip" and value_type == "keyword":
                if "::" in python_value or is_ipaddress(python_value) or eql.utils.is_cidr_pattern(python_value):
                    return python_value
            elif field_type_family == 'date' and value_type in STRING_FIELDS:
                # this will not validate datemath syntax
                return python_value

            raise self.error(value_tree, "Value doesn't match {field}'s type: {type}",
                             field=field_name, type=field_type)

        # otherwise, there's nothing to convert
        return python_value

    @classmethod
    def convert_unquoted_literal(cls, text):
        """Interpret an unquoted token: bool/null/number first, else unescaped text."""
        if text == "true":
            return True
        elif text == "false":
            return False
        elif text == "null":
            return None
        else:
            for numeric in (int, float):
                try:
                    return numeric(text)
                except ValueError:
                    pass

        # not a keyword or a number: resolve escape sequences and keep as text
        text = cls.unquoted_regex.sub(lambda r: cls.unquoted_escapes[r.group()], text)
        return text

    @classmethod
    def convert_quoted_string(cls, text):
        """Strip surrounding quotes and resolve escape sequences."""
        inner_text = text[1:-1]
        unescaped = cls.quoted_regex.sub(lambda r: cls.quoted_escapes[r.group()], inner_text)
        return unescaped
class KqlParser(BaseKqlParser):
    """Interpreter that converts a lark parse tree into the KQL AST (kql.ast)."""

    def or_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        terms = [self.visit(t) for t in tree.child_trees]
        return OrExpr(terms)

    def and_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        terms = [self.visit(t) for t in tree.child_trees]
        return AndExpr(terms)

    def not_query(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        # the final child is the negated sub-query (preceding children are NOT tokens)
        return NotExpr(self.visit(tree.children[-1]))

    @contextlib.contextmanager
    def nest(self, lark_tree):
        """Narrow the schema to a nested field's sub-schema for the block."""
        schema = self.mapping_schema
        dotted_path = self.visit(lark_tree)

        if self.get_field_type(dotted_path, lark_tree) != "nested":
            raise self.error(lark_tree, "Expected a nested field")

        try:
            self.mapping_schema = self.mapping_schema[dotted_path]
            yield
        finally:
            self.mapping_schema = schema

    def nested_query(self, tree):
        # field_tree, query_tree = tree.child_trees
        #
        # with self.nest(field_tree) as field:
        #     return NestedQuery(field, self.visit(query_tree))
        raise self.error(tree, "Nested queries are not yet supported")

    def field_value_expression(self, tree):
        field_tree, expr = tree.child_trees

        # visit the value with the field in scope so value() can type-check it
        with self.scope(self.visit(field_tree)) as field:
            # check the field against the schema
            self.get_field_types(field.name, field_tree)
            return FieldComparison(field, self.visit(expr))

    def field_range_expression(self, tree):
        field_tree, operator, literal = tree.children

        with self.scope(self.visit(field_tree)) as field:
            value = self.convert_value(field.name, self.visit(literal), literal)
            return FieldRange(field, operator, Value.from_python(value))

    def or_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return OrValues([self.visit(t) for t in tree.child_trees])

    def and_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return AndValues([self.visit(t) for t in tree.child_trees])

    def not_list_of_values(self, tree):
        self.assert_lower_token(*tree.child_tokens)
        return NotValue(self.visit(tree.children[-1]))

    def literal(self, tree):
        return self.unescape_literal(tree.children[0])

    def field(self, tree):
        literal = self.visit(tree.children[0])
        return Field(eql.utils.to_unicode(literal))

    def value(self, tree):
        """Build a Value/Wildcard/Exists node for one literal in field scope."""
        if self.scoped_field is None:
            raise self.error(tree, "Value not tied to field")

        field_name = self.scoped_field.name
        token = tree.children[0]
        value = self.unescape_literal(token)

        if token.type == "UNQUOTED_LITERAL" and "*" in token.value:
            field_type = self.get_field_type(field_name)

            # a pure-wildcard value (e.g. `field: *`) is an existence check
            if len(value.replace("*", "")) == 0:
                return Exists()

            if field_type is not None and field_type not in ("keyword", "wildcard"):
                raise self.error(tree, "Unable to perform wildcard on field {field} of {type}",
                                 field=field_name, type=field_type)

            return Wildcard(token.value)

        # try to convert the value to the appropriate type
        # example: 1 -> "1" if the field is actually keyword
        value = self.convert_value(field_name, value, tree)
        return Value.from_python(value)
def lark_parse(text):
    """Parse *text* with the shared lark parser, normalizing failures.

    Raises KqlParseError for empty input, unexpected EOF, or invalid syntax.
    """
    if not text.strip():
        raise KqlParseError("No query provided", 0, 0, "")

    # a parser instance is built only for its pre-split line view
    helper = BaseKqlParser(text)

    try:
        return lark_parser.parse(text)
    except UnexpectedEOF:
        last_line = helper.lines[-1]
        raise KqlParseError("Unexpected EOF", len(helper.lines), len(last_line.strip()), last_line)
    except LarkError as exc:
        context = '\n'.join(helper.lines[exc.line - 2:exc.line])
        raise KqlParseError("Invalid syntax", exc.line - 1, exc.column - 1, context)
+31
View File
@@ -0,0 +1,31 @@
[project]
name = "detection-rules-kql"
version = "0.1.6"
description = "Kibana Query Language parser for Elastic Detection Rules"
license = {text = "Elastic License v2"}
keywords = ["Elastic", "sour", "Detection Rules", "Security", "Elasticsearch", "kql"]
classifiers = [
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.12",
"Topic :: Security",
"Topic :: Software Development :: Build Tools",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development",
]
requires-python = ">=3.12"
dependencies = [
"eql==0.9.19",
"lark-parser>=0.11.1",
]
[project.urls]
Homepage = "https://github.com/elastic/detection-rules"
License = "https://github.com/elastic/detection-rules/blob/main/LICENSE.txt"
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.package-data]
kql = ["*.g"]