Separate out query validation from the class hierarchy (#1136)

* Separate out query validation from the class hierarchy
* Rename to *RuleData for consistency
* Apply suggestions from code review
* Fix lint error
This commit is contained in:
Ross Wolf
2021-04-21 14:55:26 -06:00
committed by GitHub
parent ff45539369
commit 8789dd7c90
7 changed files with 171 additions and 122 deletions
+2 -2
View File
@@ -25,7 +25,7 @@ from .eswrap import CollectEvents, add_range_to_dsl
from .main import root
from .misc import PYTHON_LICENSE, add_client, GithubClient, Manifest, client_error, getdefault
from .packaging import PACKAGE_FILE, Package, manage_versions, RELEASE_DIR
from .rule import TOMLRule, BaseQueryRuleData
from .rule import TOMLRule, QueryRuleData
from .rule_loader import production_filter, RuleCollection
from .utils import get_path, dict_hash
@@ -389,7 +389,7 @@ def rule_event_search(ctx, rule, date_range, count, max_results, verbose,
elasticsearch_client: Elasticsearch = None):
"""Search using a rule file against an Elasticsearch instance."""
if isinstance(rule.contents.data, BaseQueryRuleData):
if isinstance(rule.contents.data, QueryRuleData):
if verbose:
click.echo(f'Searching rule: {rule.name}')
+3 -3
View File
@@ -18,7 +18,7 @@ import click
import yaml
from .misc import JS_LICENSE, cached
from .rule import TOMLRule, BaseQueryRuleData, ThreatMapping
from .rule import TOMLRule, QueryRuleData, ThreatMapping
from .rule import downgrade_contents_from_rule
from .rule_loader import RuleCollection, DEFAULT_RULES_DIR
from .schemas import CurrentSchema, definitions
@@ -377,7 +377,7 @@ class Package(object):
def get_summary_rule_info(r: TOMLRule):
r = r.contents
rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}'
if isinstance(rule.contents.data, BaseQueryRuleData):
if isinstance(rule.contents.data, QueryRuleData):
rule_str += f'-{r.data.language}'
rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}'
@@ -387,7 +387,7 @@ class Package(object):
# lookup the rule in the GitHub tag v{major.minor.patch}
data = r.contents.data
rules_dir_link = f'https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/'
rule_type = data.language if isinstance(data, BaseQueryRuleData) else data.type
rule_type = data.language if isinstance(data, QueryRuleData) else data.type
return f'`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)'
for rule in self.rules:
+43 -108
View File
@@ -5,15 +5,14 @@
"""Rule object."""
import json
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Literal, Union, Optional, List, Any
from uuid import uuid4
import eql
from marshmallow import validates_schema
import kql
from . import ecs, beats, utils
from . import utils
from .mixins import MarshmallowDataclassMixin
from .rule_formatter import toml_write, nested_normalize
from .schemas import definitions
@@ -169,67 +168,48 @@ class BaseRuleData(MarshmallowDataclassMixin):
type: Literal[definitions.RuleType]
threat: Optional[List[ThreatMapping]]
def validate_query(self, meta: RuleMeta) -> None:
pass
@dataclass
class QueryValidator:
query: str
@property
def ast(self) -> Any:
raise NotImplementedError
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
raise NotImplementedError()
@dataclass(frozen=True)
class BaseQueryRuleData(BaseRuleData):
class QueryRuleData(BaseRuleData):
"""Specific fields for query event types."""
type: Literal["query"]
index: Optional[List[str]]
query: str
language: str
language: definitions.FilterLanguages
@property
def parsed_query(self) -> Optional[object]:
return None
@cached_property
def validator(self) -> Optional[QueryValidator]:
if self.language == "kuery":
return KQLValidator(self.query)
elif self.language == "eql":
return EQLValidator(self.query)
def validate_query(self, meta: RuleMeta) -> None:
validator = self.validator
if validator is not None:
return validator.validate(self, meta)
@dataclass(frozen=True)
class KQLRuleData(BaseQueryRuleData):
"""Specific fields for query event types."""
language: Literal["kuery"]
@property
def parsed_query(self) -> kql.ast.Expression:
return kql.parse(self.query)
@property
def unique_fields(self):
return list(set(str(f) for f in self.parsed_query if isinstance(f, kql.ast.Field)))
def to_eql(self) -> eql.ast.Expression:
return kql.to_eql(self.query)
def validate_query(self, beats_version: str, ecs_versions: List[str]):
"""Static method to validate the query, called from the parent which contains [metadata] information."""
indexes = self.index or []
parsed = self.parsed_query
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
beat_schema = beats.get_schema_from_kql(parsed, beat_types, version=beats_version) if beat_types else None
if not ecs_versions:
kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))
else:
for version in ecs_versions:
schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema)
try:
kql.parse(self.query, schema=schema)
except kql.KqlParseError as exc:
message = exc.error_msg
trailer = None
if "Unknown field" in message and beat_types:
trailer = "\nTry adding event.module or event.dataset to specify beats module"
raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
len(exc.caret.lstrip()), trailer=trailer) from None
@dataclass(frozen=True)
class LuceneRuleData(BaseQueryRuleData):
"""Specific fields for query event types."""
language: Literal["lucene"]
@cached_property
def ast(self):
validator = self.validator
if validator is not None:
return validator.ast
@dataclass(frozen=True)
@@ -241,7 +221,7 @@ class MachineLearningRuleData(BaseRuleData):
@dataclass(frozen=True)
class ThresholdQueryRuleData(BaseQueryRuleData):
class ThresholdQueryRuleData(QueryRuleData):
"""Specific fields for query event types."""
@dataclass(frozen=True)
@@ -256,57 +236,18 @@ class ThresholdQueryRuleData(BaseQueryRuleData):
cardinality: Optional[ThresholdCardinality]
type: Literal["threshold"]
language: Literal["kuery", "lucene"]
threshold: ThresholdMapping
@dataclass(frozen=True)
class EQLRuleData(BaseQueryRuleData):
class EQLRuleData(QueryRuleData):
"""EQL rules are a special case of query rules."""
type: Literal["eql"]
@property
def parsed_query(self) -> kql.ast.Expression:
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
return eql.parse_query(self.query)
@property
def unique_fields(self):
return list(set(str(f) for f in self.parsed_query if isinstance(f, eql.ast.Field)))
def validate_query(self, beats_version: str, ecs_versions: List[str]):
"""Validate an EQL query while checking TOMLRule."""
# TODO: remove once py-eql supports ipv6 for cidrmatch
# Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
parsed = eql.parse_query(self.query)
beat_types = [index.split("-")[0] for index in self.index or [] if "beat-*" in index]
beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None
for version in ecs_versions:
schema = ecs.get_kql_schema(indexes=self.index or [], beat_schema=beat_schema, version=version)
try:
# TODO: switch to custom cidrmatch that allows ipv6
with ecs.KqlSchema2Eql(schema), eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
eql.parse_query(self.query)
except eql.EqlTypeMismatchError:
raise
except eql.EqlParseError as exc:
message = exc.error_msg
trailer = None
if "Unknown field" in message and beat_types:
trailer = "\nTry adding event.module or event.dataset to specify beats module"
raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source,
len(exc.caret.lstrip()), trailer=trailer) from None
language: Literal["eql"]
# All of the possible rule types
AnyRuleData = Union[KQLRuleData, LuceneRuleData, MachineLearningRuleData, ThresholdQueryRuleData, EQLRuleData]
AnyRuleData = Union[QueryRuleData, EQLRuleData, MachineLearningRuleData, ThresholdQueryRuleData]
@dataclass(frozen=True)
@@ -365,17 +306,7 @@ class TOMLRuleContents(MarshmallowDataclassMixin):
data: AnyRuleData = value["data"]
metadata: RuleMeta = value["metadata"]
beats_version = metadata.beats_version or beats.get_max_version()
ecs_versions = metadata.ecs_versions or [ecs.get_max_version()]
# call into these validate methods
if isinstance(data, (EQLRuleData, KQLRuleData)):
if metadata.query_schema_validation is False or metadata.maturity == "deprecated":
# Check the syntax only
_ = data.parsed_query
else:
# otherwise, do a full schema validation
data.validate_query(beats_version=beats_version, ecs_versions=ecs_versions)
return data.validate_query(metadata)
def to_dict(self, strip_none_values=True) -> dict:
dict_obj = super(TOMLRuleContents, self).to_dict(strip_none_values=strip_none_values)
@@ -454,3 +385,7 @@ def downgrade_contents_from_rule(rule: TOMLRule, target_version: str) -> dict:
payload["rule_id"] = str(uuid4())
payload = downgrade(payload, target_version)
return payload
# avoid a circular import
from .rule_validators import KQLValidator, EQLValidator # noqa: E402
+114
View File
@@ -0,0 +1,114 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Validation logic for rules containing queries."""
from functools import cached_property
from typing import List
import eql
import kql
from detection_rules import beats, ecs
from detection_rules.rule import QueryValidator, QueryRuleData, RuleMeta
class KQLValidator(QueryValidator):
"""Specific fields for query event types."""
@cached_property
def ast(self) -> kql.ast.Expression:
return kql.parse(self.query)
@property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field)))
def to_eql(self) -> eql.ast.Expression:
return kql.to_eql(self.query)
def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
"""Static method to validate the query, called from the parent which contains [metadata] information."""
ast = self.ast
if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
indexes = data.index or []
beats_version = meta.beats_version or beats.get_max_version()
ecs_versions = meta.ecs_versions or [ecs.get_max_version()]
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
if not ecs_versions:
kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))
else:
for version in ecs_versions:
schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema)
try:
kql.parse(self.query, schema=schema)
except kql.KqlParseError as exc:
message = exc.error_msg
trailer = None
if "Unknown field" in message and beat_types:
trailer = "\nTry adding event.module or event.dataset to specify beats module"
raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
len(exc.caret.lstrip()), trailer=trailer) from None
class EQLValidator(QueryValidator):
@cached_property
def ast(self) -> kql.ast.Expression:
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
return eql.parse_query(self.query)
@property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field)))
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
"""Validate an EQL query while checking TOMLRule."""
_ = self.ast
if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
indexes = data.index or []
beats_version = meta.beats_version or beats.get_max_version()
ecs_versions = meta.ecs_versions or [ecs.get_max_version()]
# TODO: remove once py-eql supports ipv6 for cidrmatch
# Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
parsed = eql.parse_query(self.query)
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None
for version in ecs_versions:
schema = ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version)
eql_schema = ecs.KqlSchema2Eql(schema)
try:
# TODO: switch to custom cidrmatch that allows ipv6
with eql_schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
eql.parse_query(self.query)
except eql.EqlTypeMismatchError:
raise
except eql.EqlParseError as exc:
message = exc.error_msg
trailer = None
if "Unknown field" in message and beat_types:
trailer = "\nTry adding event.module or event.dataset to specify beats module"
raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source,
len(exc.caret.lstrip()), trailer=trailer) from None
+1
View File
@@ -40,6 +40,7 @@ CodeString = NewType("CodeString", str)
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
FilterLanguages = Literal["kuery", "lucene"]
Markdown = NewType("MarkdownField", CodeString)
Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated']
MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1))
+6 -6
View File
@@ -14,7 +14,7 @@ import eql
import kql
from detection_rules import attack, beats, ecs
from detection_rules.packaging import load_versions
from detection_rules.rule import BaseQueryRuleData
from detection_rules.rule import QueryRuleData
from detection_rules.rule_loader import FILE_PATTERN
from detection_rules.utils import get_path, load_etc_dump
from rta import get_ttp_names
@@ -58,7 +58,7 @@ class TestValidRules(BaseRuleTest):
ttp_names = get_ttp_names()
for rule in self.production_rules:
if isinstance(rule.contents.data, BaseQueryRuleData) and rule.id in mappings:
if isinstance(rule.contents.data, QueryRuleData) and rule.id in mappings:
matching_rta = mappings[rule.id].get('rta_name')
self.assertIsNotNone(matching_rta, f'{self.rule_str(rule)} does not have RTAs')
@@ -232,7 +232,7 @@ class TestRuleTags(BaseRuleTest):
if 'Elastic' not in rule_tags:
missing_required_tags.add('Elastic')
if isinstance(rule.contents.data, BaseQueryRuleData):
if isinstance(rule.contents.data, QueryRuleData):
for index in rule.contents.data.index:
expected_tags = required_tags_map.get(index, {})
expected_all = expected_tags.get('all', [])
@@ -440,13 +440,13 @@ class TestTuleTiming(BaseRuleTest):
for rule in self.all_rules:
required = False
if isinstance(rule.contents.data, BaseQueryRuleData) and 'endgame-*' in rule.contents.data.index:
if isinstance(rule.contents.data, QueryRuleData) and 'endgame-*' in rule.contents.data.index:
continue
if rule.contents.data.type == 'query':
required = True
elif rule.contents.data.type == 'eql' and \
eql.utils.get_query_type(rule.contents.data.parsed_query) != 'sequence':
eql.utils.get_query_type(rule.contents.data.ast) != 'sequence':
required = True
if required and rule.contents.data.timestamp_override != 'event.ingested':
@@ -465,7 +465,7 @@ class TestTuleTiming(BaseRuleTest):
for rule in self.all_rules:
contents = rule.contents
if isinstance(contents.data, BaseQueryRuleData):
if isinstance(contents.data, QueryRuleData):
if set(getattr(contents.data, "index", None) or []) & long_indexes and not contents.data.from_:
missing.append(rule)
+2 -3
View File
@@ -7,7 +7,6 @@
import copy
import warnings
from detection_rules.rule import KQLRuleData
from . import get_data_files, get_fp_data_files
from detection_rules.utils import combine_sources, evaluate, load_etc_dump
from .base import BaseRuleTest
@@ -30,7 +29,7 @@ class TestMappings(BaseRuleTest):
mappings = load_etc_dump('rule-mapping.yml')
for rule in self.production_rules:
if isinstance(rule.contents.data, KQLRuleData):
if rule.contents.data.type == "query" and rule.contents.data.language == "kuery":
if rule.id not in mappings:
continue
@@ -63,7 +62,7 @@ class TestMappings(BaseRuleTest):
def test_false_positives(self):
"""Test that expected results return against false positives."""
for rule in self.production_rules:
if isinstance(rule.contents.data, KQLRuleData):
if rule.contents.data.type == "query" and rule.contents.data.language == "kuery":
for fp_name, merged_data in get_fp_data_files().items():
msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name)
self.evaluate(copy.deepcopy(merged_data), rule, 0, msg)