Separate out query validation from the class hierarchy (#1136)
* Separate out query validation from the class hierarchy
* Rename to *RuleData for consistency
* Apply suggestions from code review
* Fix lint error
This commit is contained in:
@@ -25,7 +25,7 @@ from .eswrap import CollectEvents, add_range_to_dsl
|
||||
from .main import root
|
||||
from .misc import PYTHON_LICENSE, add_client, GithubClient, Manifest, client_error, getdefault
|
||||
from .packaging import PACKAGE_FILE, Package, manage_versions, RELEASE_DIR
|
||||
from .rule import TOMLRule, BaseQueryRuleData
|
||||
from .rule import TOMLRule, QueryRuleData
|
||||
from .rule_loader import production_filter, RuleCollection
|
||||
from .utils import get_path, dict_hash
|
||||
|
||||
@@ -389,7 +389,7 @@ def rule_event_search(ctx, rule, date_range, count, max_results, verbose,
|
||||
elasticsearch_client: Elasticsearch = None):
|
||||
"""Search using a rule file against an Elasticsearch instance."""
|
||||
|
||||
if isinstance(rule.contents.data, BaseQueryRuleData):
|
||||
if isinstance(rule.contents.data, QueryRuleData):
|
||||
if verbose:
|
||||
click.echo(f'Searching rule: {rule.name}')
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import click
|
||||
import yaml
|
||||
|
||||
from .misc import JS_LICENSE, cached
|
||||
from .rule import TOMLRule, BaseQueryRuleData, ThreatMapping
|
||||
from .rule import TOMLRule, QueryRuleData, ThreatMapping
|
||||
from .rule import downgrade_contents_from_rule
|
||||
from .rule_loader import RuleCollection, DEFAULT_RULES_DIR
|
||||
from .schemas import CurrentSchema, definitions
|
||||
@@ -377,7 +377,7 @@ class Package(object):
|
||||
def get_summary_rule_info(r: TOMLRule):
|
||||
r = r.contents
|
||||
rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}'
|
||||
if isinstance(rule.contents.data, BaseQueryRuleData):
|
||||
if isinstance(rule.contents.data, QueryRuleData):
|
||||
rule_str += f'-{r.data.language}'
|
||||
rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}'
|
||||
|
||||
@@ -387,7 +387,7 @@ class Package(object):
|
||||
# lookup the rule in the GitHub tag v{major.minor.patch}
|
||||
data = r.contents.data
|
||||
rules_dir_link = f'https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/'
|
||||
rule_type = data.language if isinstance(data, BaseQueryRuleData) else data.type
|
||||
rule_type = data.language if isinstance(data, QueryRuleData) else data.type
|
||||
return f'`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)'
|
||||
|
||||
for rule in self.rules:
|
||||
|
||||
+43
-108
@@ -5,15 +5,14 @@
|
||||
"""Rule object."""
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union, Optional, List, Any
|
||||
from uuid import uuid4
|
||||
|
||||
import eql
|
||||
from marshmallow import validates_schema
|
||||
|
||||
import kql
|
||||
from . import ecs, beats, utils
|
||||
from . import utils
|
||||
from .mixins import MarshmallowDataclassMixin
|
||||
from .rule_formatter import toml_write, nested_normalize
|
||||
from .schemas import definitions
|
||||
@@ -169,67 +168,48 @@ class BaseRuleData(MarshmallowDataclassMixin):
|
||||
type: Literal[definitions.RuleType]
|
||||
threat: Optional[List[ThreatMapping]]
|
||||
|
||||
def validate_query(self, meta: RuleMeta) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class QueryValidator:
    """Base class for language-specific query validators.

    Stores the raw query string. Subclasses provide the parsed form through
    ``ast`` and the checking logic through ``validate``; both stubs here
    raise ``NotImplementedError``.
    """

    # raw query text exactly as it appears in the rule
    query: str

    @property
    def ast(self) -> Any:
        """Return the parsed form of ``query``; must be overridden by subclasses."""
        raise NotImplementedError()

    def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
        """Validate ``query`` against the rule ``data`` and ``meta``; must be overridden by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseQueryRuleData(BaseRuleData):
|
||||
class QueryRuleData(BaseRuleData):
|
||||
"""Specific fields for query event types."""
|
||||
type: Literal["query"]
|
||||
|
||||
index: Optional[List[str]]
|
||||
query: str
|
||||
language: str
|
||||
language: definitions.FilterLanguages
|
||||
|
||||
@property
|
||||
def parsed_query(self) -> Optional[object]:
|
||||
return None
|
||||
@cached_property
|
||||
def validator(self) -> Optional[QueryValidator]:
|
||||
if self.language == "kuery":
|
||||
return KQLValidator(self.query)
|
||||
elif self.language == "eql":
|
||||
return EQLValidator(self.query)
|
||||
|
||||
def validate_query(self, meta: RuleMeta) -> None:
|
||||
validator = self.validator
|
||||
if validator is not None:
|
||||
return validator.validate(self, meta)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KQLRuleData(BaseQueryRuleData):
|
||||
"""Specific fields for query event types."""
|
||||
language: Literal["kuery"]
|
||||
|
||||
@property
|
||||
def parsed_query(self) -> kql.ast.Expression:
|
||||
return kql.parse(self.query)
|
||||
|
||||
@property
|
||||
def unique_fields(self):
|
||||
return list(set(str(f) for f in self.parsed_query if isinstance(f, kql.ast.Field)))
|
||||
|
||||
def to_eql(self) -> eql.ast.Expression:
|
||||
return kql.to_eql(self.query)
|
||||
|
||||
def validate_query(self, beats_version: str, ecs_versions: List[str]):
|
||||
"""Static method to validate the query, called from the parent which contains [metadata] information."""
|
||||
indexes = self.index or []
|
||||
parsed = self.parsed_query
|
||||
|
||||
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
|
||||
beat_schema = beats.get_schema_from_kql(parsed, beat_types, version=beats_version) if beat_types else None
|
||||
|
||||
if not ecs_versions:
|
||||
kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))
|
||||
else:
|
||||
for version in ecs_versions:
|
||||
schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema)
|
||||
|
||||
try:
|
||||
kql.parse(self.query, schema=schema)
|
||||
except kql.KqlParseError as exc:
|
||||
message = exc.error_msg
|
||||
trailer = None
|
||||
if "Unknown field" in message and beat_types:
|
||||
trailer = "\nTry adding event.module or event.dataset to specify beats module"
|
||||
|
||||
raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
|
||||
len(exc.caret.lstrip()), trailer=trailer) from None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LuceneRuleData(BaseQueryRuleData):
|
||||
"""Specific fields for query event types."""
|
||||
language: Literal["lucene"]
|
||||
@cached_property
|
||||
def ast(self):
|
||||
validator = self.validator
|
||||
if validator is not None:
|
||||
return validator.ast
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -241,7 +221,7 @@ class MachineLearningRuleData(BaseRuleData):
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ThresholdQueryRuleData(BaseQueryRuleData):
|
||||
class ThresholdQueryRuleData(QueryRuleData):
|
||||
"""Specific fields for query event types."""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -256,57 +236,18 @@ class ThresholdQueryRuleData(BaseQueryRuleData):
|
||||
cardinality: Optional[ThresholdCardinality]
|
||||
|
||||
type: Literal["threshold"]
|
||||
language: Literal["kuery", "lucene"]
|
||||
threshold: ThresholdMapping
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EQLRuleData(BaseQueryRuleData):
|
||||
class EQLRuleData(QueryRuleData):
|
||||
"""EQL rules are a special case of query rules."""
|
||||
type: Literal["eql"]
|
||||
|
||||
@property
|
||||
def parsed_query(self) -> kql.ast.Expression:
|
||||
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
|
||||
return eql.parse_query(self.query)
|
||||
|
||||
@property
|
||||
def unique_fields(self):
|
||||
return list(set(str(f) for f in self.parsed_query if isinstance(f, eql.ast.Field)))
|
||||
|
||||
def validate_query(self, beats_version: str, ecs_versions: List[str]):
|
||||
"""Validate an EQL query while checking TOMLRule."""
|
||||
# TODO: remove once py-eql supports ipv6 for cidrmatch
|
||||
# Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4
|
||||
with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
|
||||
parsed = eql.parse_query(self.query)
|
||||
|
||||
beat_types = [index.split("-")[0] for index in self.index or [] if "beat-*" in index]
|
||||
beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None
|
||||
|
||||
for version in ecs_versions:
|
||||
schema = ecs.get_kql_schema(indexes=self.index or [], beat_schema=beat_schema, version=version)
|
||||
|
||||
try:
|
||||
# TODO: switch to custom cidrmatch that allows ipv6
|
||||
with ecs.KqlSchema2Eql(schema), eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
|
||||
eql.parse_query(self.query)
|
||||
|
||||
except eql.EqlTypeMismatchError:
|
||||
raise
|
||||
|
||||
except eql.EqlParseError as exc:
|
||||
message = exc.error_msg
|
||||
trailer = None
|
||||
if "Unknown field" in message and beat_types:
|
||||
trailer = "\nTry adding event.module or event.dataset to specify beats module"
|
||||
|
||||
raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source,
|
||||
len(exc.caret.lstrip()), trailer=trailer) from None
|
||||
language: Literal["eql"]
|
||||
|
||||
|
||||
# All of the possible rule types
|
||||
AnyRuleData = Union[KQLRuleData, LuceneRuleData, MachineLearningRuleData, ThresholdQueryRuleData, EQLRuleData]
|
||||
AnyRuleData = Union[QueryRuleData, EQLRuleData, MachineLearningRuleData, ThresholdQueryRuleData]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -365,17 +306,7 @@ class TOMLRuleContents(MarshmallowDataclassMixin):
|
||||
data: AnyRuleData = value["data"]
|
||||
metadata: RuleMeta = value["metadata"]
|
||||
|
||||
beats_version = metadata.beats_version or beats.get_max_version()
|
||||
ecs_versions = metadata.ecs_versions or [ecs.get_max_version()]
|
||||
|
||||
# call into these validate methods
|
||||
if isinstance(data, (EQLRuleData, KQLRuleData)):
|
||||
if metadata.query_schema_validation is False or metadata.maturity == "deprecated":
|
||||
# Check the syntax only
|
||||
_ = data.parsed_query
|
||||
else:
|
||||
# otherwise, do a full schema validation
|
||||
data.validate_query(beats_version=beats_version, ecs_versions=ecs_versions)
|
||||
return data.validate_query(metadata)
|
||||
|
||||
def to_dict(self, strip_none_values=True) -> dict:
|
||||
dict_obj = super(TOMLRuleContents, self).to_dict(strip_none_values=strip_none_values)
|
||||
@@ -454,3 +385,7 @@ def downgrade_contents_from_rule(rule: TOMLRule, target_version: str) -> dict:
|
||||
payload["rule_id"] = str(uuid4())
|
||||
payload = downgrade(payload, target_version)
|
||||
return payload
|
||||
|
||||
|
||||
# avoid a circular import
|
||||
from .rule_validators import KQLValidator, EQLValidator # noqa: E402
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
# or more contributor license agreements. Licensed under the Elastic License
|
||||
# 2.0; you may not use this file except in compliance with the Elastic License
|
||||
# 2.0.
|
||||
|
||||
"""Validation logic for rules containing queries."""
|
||||
from functools import cached_property
|
||||
from typing import List
|
||||
|
||||
import eql
|
||||
|
||||
import kql
|
||||
from detection_rules import beats, ecs
|
||||
from detection_rules.rule import QueryValidator, QueryRuleData, RuleMeta
|
||||
|
||||
|
||||
class KQLValidator(QueryValidator):
    """Validator for rule queries written in KQL."""

    @cached_property
    def ast(self) -> kql.ast.Expression:
        """Parse the query with ``kql.parse`` (no schema); the result is cached on the instance."""
        return kql.parse(self.query)

    @property
    def unique_fields(self) -> List[str]:
        """Return the de-duplicated string forms of the ``kql.ast.Field`` nodes yielded by iterating the AST."""
        return list({str(f) for f in self.ast if isinstance(f, kql.ast.Field)})

    def to_eql(self) -> eql.ast.Expression:
        """Convert the KQL query to its EQL equivalent via ``kql.to_eql``."""
        return kql.to_eql(self.query)

    def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
        """Validate the query syntax and, unless disabled, its fields against each ECS schema version.

        :param data: rule data; only ``data.index`` is read here.
        :param meta: rule metadata supplying ``query_schema_validation``, ``maturity``,
            ``beats_version`` and ``ecs_versions``.
        :raises kql.KqlParseError: when parsing against a schema fails; the error is
            re-raised with an extra hint when an unknown field is hit on a beats index.
        """
        # Accessing the cached AST parses the query, so syntax is always checked.
        ast = self.ast

        if meta.query_schema_validation is False or meta.maturity == "deprecated":
            # syntax only, which is done via self.ast
            return

        indexes = data.index or []
        beats_version = meta.beats_version or beats.get_max_version()
        # Falls back to a one-element list, so there is always at least one version to check.
        ecs_versions = meta.ecs_versions or [ecs.get_max_version()]

        beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
        beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None

        for version in ecs_versions:
            schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema)

            try:
                kql.parse(self.query, schema=schema)
            except kql.KqlParseError as exc:
                trailer = None
                if "Unknown field" in exc.error_msg and beat_types:
                    trailer = "\nTry adding event.module or event.dataset to specify beats module"

                # Rebuild the error so the trailer is attached; `from None` hides the original traceback.
                raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
                                        len(exc.caret.lstrip()), trailer=trailer) from None
|
||||
|
||||
|
||||
class EQLValidator(QueryValidator):
    """Validator for rule queries written in EQL."""

    @cached_property
    def ast(self) -> eql.ast.Expression:
        """Parse the query with ``eql.parse_query`` (no schema); the result is cached on the instance."""
        with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
            return eql.parse_query(self.query)

    @property
    def unique_fields(self) -> List[str]:
        """Return the de-duplicated string forms of the ``eql.ast.Field`` nodes yielded by iterating the AST."""
        return list({str(f) for f in self.ast if isinstance(f, eql.ast.Field)})

    def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
        """Validate the query syntax and, unless disabled, its fields against each ECS schema version.

        :param data: rule data; only ``data.index`` is read here.
        :param meta: rule metadata supplying ``query_schema_validation``, ``maturity``,
            ``beats_version`` and ``ecs_versions``.
        :raises eql.EqlTypeMismatchError: propagated unchanged from schema parsing.
        :raises eql.EqlParseError: other parse errors are re-raised as the same class,
            with an extra hint when an unknown field is hit on a beats index.
        """
        # Accessing the cached AST parses the query, so syntax is always checked.
        ast = self.ast

        if meta.query_schema_validation is False or meta.maturity == "deprecated":
            # syntax only, which is done via self.ast
            return

        indexes = data.index or []
        beats_version = meta.beats_version or beats.get_max_version()
        ecs_versions = meta.ecs_versions or [ecs.get_max_version()]

        # The cached AST was produced under the same parser contexts, so it is reused
        # here instead of parsing the query a second time.
        beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
        beat_schema = beats.get_schema_from_eql(ast, beat_types, version=beats_version) if beat_types else None

        for version in ecs_versions:
            schema = ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version)
            eql_schema = ecs.KqlSchema2Eql(schema)

            try:
                # TODO: remove once py-eql supports ipv6 for cidrmatch
                # Or, unregister the cidrMatch function and replace it with one that doesn't validate
                # against strict IPv4
                # TODO: switch to custom cidrmatch that allows ipv6
                with eql_schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
                    eql.parse_query(self.query)

            except eql.EqlTypeMismatchError:
                # Listed before the EqlParseError handler so these errors are re-raised untouched.
                raise

            except eql.EqlParseError as exc:
                trailer = None
                if "Unknown field" in exc.error_msg and beat_types:
                    trailer = "\nTry adding event.module or event.dataset to specify beats module"

                # Rebuild the error with the same class so the trailer is attached;
                # `from None` hides the original traceback.
                raise exc.__class__(exc.error_msg, exc.line, exc.column, exc.source,
                                    len(exc.caret.lstrip()), trailer=trailer) from None
|
||||
@@ -40,6 +40,7 @@ CodeString = NewType("CodeString", str)
|
||||
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
|
||||
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
|
||||
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
|
||||
FilterLanguages = Literal["kuery", "lucene"]
|
||||
Markdown = NewType("MarkdownField", CodeString)
|
||||
Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated']
|
||||
MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1))
|
||||
|
||||
@@ -14,7 +14,7 @@ import eql
|
||||
import kql
|
||||
from detection_rules import attack, beats, ecs
|
||||
from detection_rules.packaging import load_versions
|
||||
from detection_rules.rule import BaseQueryRuleData
|
||||
from detection_rules.rule import QueryRuleData
|
||||
from detection_rules.rule_loader import FILE_PATTERN
|
||||
from detection_rules.utils import get_path, load_etc_dump
|
||||
from rta import get_ttp_names
|
||||
@@ -58,7 +58,7 @@ class TestValidRules(BaseRuleTest):
|
||||
ttp_names = get_ttp_names()
|
||||
|
||||
for rule in self.production_rules:
|
||||
if isinstance(rule.contents.data, BaseQueryRuleData) and rule.id in mappings:
|
||||
if isinstance(rule.contents.data, QueryRuleData) and rule.id in mappings:
|
||||
matching_rta = mappings[rule.id].get('rta_name')
|
||||
|
||||
self.assertIsNotNone(matching_rta, f'{self.rule_str(rule)} does not have RTAs')
|
||||
@@ -232,7 +232,7 @@ class TestRuleTags(BaseRuleTest):
|
||||
if 'Elastic' not in rule_tags:
|
||||
missing_required_tags.add('Elastic')
|
||||
|
||||
if isinstance(rule.contents.data, BaseQueryRuleData):
|
||||
if isinstance(rule.contents.data, QueryRuleData):
|
||||
for index in rule.contents.data.index:
|
||||
expected_tags = required_tags_map.get(index, {})
|
||||
expected_all = expected_tags.get('all', [])
|
||||
@@ -440,13 +440,13 @@ class TestTuleTiming(BaseRuleTest):
|
||||
for rule in self.all_rules:
|
||||
required = False
|
||||
|
||||
if isinstance(rule.contents.data, BaseQueryRuleData) and 'endgame-*' in rule.contents.data.index:
|
||||
if isinstance(rule.contents.data, QueryRuleData) and 'endgame-*' in rule.contents.data.index:
|
||||
continue
|
||||
|
||||
if rule.contents.data.type == 'query':
|
||||
required = True
|
||||
elif rule.contents.data.type == 'eql' and \
|
||||
eql.utils.get_query_type(rule.contents.data.parsed_query) != 'sequence':
|
||||
eql.utils.get_query_type(rule.contents.data.ast) != 'sequence':
|
||||
required = True
|
||||
|
||||
if required and rule.contents.data.timestamp_override != 'event.ingested':
|
||||
@@ -465,7 +465,7 @@ class TestTuleTiming(BaseRuleTest):
|
||||
for rule in self.all_rules:
|
||||
contents = rule.contents
|
||||
|
||||
if isinstance(contents.data, BaseQueryRuleData):
|
||||
if isinstance(contents.data, QueryRuleData):
|
||||
if set(getattr(contents.data, "index", None) or []) & long_indexes and not contents.data.from_:
|
||||
missing.append(rule)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from detection_rules.rule import KQLRuleData
|
||||
from . import get_data_files, get_fp_data_files
|
||||
from detection_rules.utils import combine_sources, evaluate, load_etc_dump
|
||||
from .base import BaseRuleTest
|
||||
@@ -30,7 +29,7 @@ class TestMappings(BaseRuleTest):
|
||||
mappings = load_etc_dump('rule-mapping.yml')
|
||||
|
||||
for rule in self.production_rules:
|
||||
if isinstance(rule.contents.data, KQLRuleData):
|
||||
if rule.contents.data.type == "query" and rule.contents.data.language == "kuery":
|
||||
if rule.id not in mappings:
|
||||
continue
|
||||
|
||||
@@ -63,7 +62,7 @@ class TestMappings(BaseRuleTest):
|
||||
def test_false_positives(self):
|
||||
"""Test that expected results return against false positives."""
|
||||
for rule in self.production_rules:
|
||||
if isinstance(rule.contents.data, KQLRuleData):
|
||||
if rule.contents.data.type == "query" and rule.contents.data.language == "kuery":
|
||||
for fp_name, merged_data in get_fp_data_files().items():
|
||||
msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name)
|
||||
self.evaluate(copy.deepcopy(merged_data), rule, 0, msg)
|
||||
|
||||
Reference in New Issue
Block a user