[Bug] Query validation failing to capture InSet edge case with ip field types (#3572)

* Move test case to separate file

---------

Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com>
Co-authored-by: shashank-elastic <91139415+shashank-elastic@users.noreply.github.com>
This commit is contained in:
Eric Forte
2024-05-06 07:58:42 -04:00
committed by GitHub
parent 51268581a8
commit a4a0bc6a7e
3 changed files with 214 additions and 49 deletions
+73 -2
View File
@@ -4,10 +4,14 @@
# 2.0.
"""Validation logic for rules containing queries."""
from functools import cached_property
from typing import List, Optional, Tuple, Union
from enum import Enum
from functools import cached_property, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import eql
from eql import ast
from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint
from eql.parser import _parse as base_parse
from marshmallow import ValidationError
from semver import Version
@@ -31,6 +35,73 @@ EQL_ERROR_TYPES = Union[eql.EqlCompileError,
KQL_ERROR_TYPES = Union[kql.KqlCompileError, kql.KqlParseError]
class ExtendedTypeHint(Enum):
IP = "ip"
@classmethod
def primitives(cls):
"""Get all primitive types."""
return TypeHint.Boolean, TypeHint.Numeric, TypeHint.Null, TypeHint.String, ExtendedTypeHint.IP
def is_primitive(self):
"""Check if a type is a primitive."""
return self in self.primitives()
def custom_in_set(self, node: KvTree) -> NodeInfo:
"""Override and address the limitations of the eql in_set method."""
# return BaseInSetMethod(self, node)
outer, container = self.visit(node.child_trees) # type: (NodeInfo, list[NodeInfo])
if not outer.validate_type(ExtendedTypeHint.primitives()):
# can't compare non-primitives to sets
raise self._type_error(outer, ExtendedTypeHint.primitives())
# Check that everything inside the container has the same type as outside
error_message = "Unable to compare {expected_type} to {actual_type}"
for inner in container:
if not inner.validate_type(outer):
raise self._type_error(inner, outer, error_message)
if self._elasticsearch_syntax and hasattr(outer, "type_info"):
# Check edge case of in_set and ip/string comparison
outer_type = outer.type_info
if isinstance(self._schema, ecs.KqlSchema2Eql):
type_hint = self._schema.kql_schema.get(str(outer.node), "unknown")
if hasattr(self._schema, "type_mapping") and type_hint == "ip":
outer.type_info = ExtendedTypeHint.IP
for inner in container:
if not inner.validate_type(outer):
raise self._type_error(inner, outer, error_message)
# reset the type
outer.type_info = outer_type
# This will always evaluate to true/false, so it should be a boolean
term = ast.InSet(outer.node, [c.node for c in container])
nullable = outer.nullable or any(c.nullable for c in container)
return NodeInfo(term, TypeHint.Boolean, nullable=nullable, source=node)
def custom_base_parse_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
"""Override and address the limitations of the eql in_set method."""
@wraps(func)
def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) -> Any:
original_in_set = LarkToEQL.in_set
LarkToEQL.in_set = custom_in_set
try:
result = func(query, start=start, **kwargs)
finally: # Using finally to ensure that the original method is restored
LarkToEQL.in_set = original_in_set
return result
return wrapper
eql.parser._parse = custom_base_parse_decorator(base_parse)
class KQLValidator(QueryValidator):
"""Specific fields for KQL query event types."""
+68
View File
@@ -0,0 +1,68 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from detection_rules.rule_loader import RuleCollection
from .base import BaseRuleTest
class TestEQLInSet(BaseRuleTest):
"""Test EQL rule query in set override."""
def test_eql_in_set(self):
"""Test that the query validation is working correctly."""
rc = RuleCollection()
eql_rule = {
"metadata": {
"creation_date": "2020/12/15",
"integration": ["endpoint", "windows"],
"maturity": "production",
"min_stack_comments": "New fields added: required_fields, related_integrations, setup",
"min_stack_version": "8.3.0",
"updated_date": "2024/03/26",
},
"rule": {
"author": ["Elastic"],
"description": """
Test Rule.
""",
"false_positives": ["Fake."],
"from": "now-9m",
"index": ["winlogbeat-*", "logs-endpoint.events.*", "logs-windows.sysmon_operational-*"],
"language": "eql",
"license": "Elastic License v2",
"name": "Fake Test Rule",
"references": [
"https://example.com",
],
"risk_score": 47,
"rule_id": "4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
"severity": "medium",
"tags": [
"Domain: Endpoint",
"OS: Windows",
"Use Case: Threat Detection",
"Tactic: Execution",
"Data Source: Elastic Defend",
"Data Source: Sysmon",
],
"type": "eql",
"query": """
sequence by host.id, process.entity_id with maxspan = 5s
[network where destination.ip in ("127.0.0.1", "::1")]
""",
},
}
expected_error_message = r"Error in both stack and integrations checks:.*Unable to compare ip to string.*"
with self.assertRaisesRegex(ValueError, expected_error_message):
rc.load_dict(eql_rule)
# Change to appropriate destination.address field
eql_rule["rule"][
"query"
] = """
sequence by host.id, process.entity_id with maxspan = 10s
[network where destination.address in ("192.168.1.1", "::1")]
"""
rc.load_dict(eql_rule)
+73 -47
View File
@@ -13,7 +13,9 @@ from semver import Version
import kql
from detection_rules.integrations import (
find_latest_compatible_version, load_integrations_manifests, load_integrations_schemas
find_latest_compatible_version,
load_integrations_manifests,
load_integrations_schemas,
)
from detection_rules.misc import load_current_package_version
from detection_rules.packaging import current_stack_version
@@ -23,31 +25,34 @@ from detection_rules.schemas import get_stack_schemas
from detection_rules.utils import get_path, load_rule_contents
from .base import BaseRuleTest
PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), optional_minor_and_patch=True)
class TestEndpointQuery(BaseRuleTest):
"""Test endpoint-specific rules."""
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
"Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
"Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0",
)
def test_os_and_platform_in_query(self):
"""Test that all endpoint rules have an os defined and linux includes platform."""
for rule in self.production_rules:
if not rule.contents.data.get('language') in ('eql', 'kuery'):
if not rule.contents.data.get("language") in ("eql", "kuery"):
continue
if rule.path.parent.name not in ('windows', 'macos', 'linux'):
if rule.path.parent.name not in ("windows", "macos", "linux"):
# skip cross-platform for now
continue
ast = rule.contents.data.ast
fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))]
err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule'
if 'host.os.type' not in fields:
err_msg = f"{self.rule_str(rule)} missing required field for endpoint rule"
if "host.os.type" not in fields:
# Exception for Forwarded Events which contain Windows-only fields.
if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields):
self.assertIn('host.os.type', fields, err_msg)
if rule.path.parent.name == "windows" and not any(field.startswith("winlog.") for field in fields):
self.assertIn("host.os.type", fields, err_msg)
# going to bypass this for now
# if rule.path.parent.name == 'linux':
@@ -58,8 +63,9 @@ class TestEndpointQuery(BaseRuleTest):
class TestNewTerms(BaseRuleTest):
"""Test new term rules."""
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_history_window_start(self):
"""Test new terms history window start field."""
@@ -67,39 +73,49 @@ class TestNewTerms(BaseRuleTest):
if rule.contents.data.type == "new_terms":
# validate history window start field exists and is correct
assert rule.contents.data.new_terms.history_window_start, \
"new terms field found with no history_window_start field defined"
assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \
f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"
assert (
rule.contents.data.new_terms.history_window_start
), "new terms field found with no history_window_start field defined"
assert (
rule.contents.data.new_terms.history_window_start[0].field == "history_window_start"
), f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_field_exists(self):
# validate new terms and history window start fields are correct
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert rule.contents.data.new_terms.field == "new_terms_fields", \
f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
assert (
rule.contents.data.new_terms.field == "new_terms_fields"
), f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_fields(self):
"""Test new terms fields are schema validated."""
# ecs validation
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
feature_min_stack = Version.parse("8.4.0")
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version
min_stack_version = (
Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None
)
min_stack_version = (
current_package_version
if min_stack_version is None or min_stack_version < current_package_version
else min_stack_version
)
assert min_stack_version >= feature_min_stack, \
f"New Terms rule types only compatible with {feature_min_stack}+"
ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs']
beats_version = get_stack_schemas()[str(min_stack_version)]['beats']
assert (
min_stack_version >= feature_min_stack
), f"New Terms rule types only compatible with {feature_min_stack}+"
ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"]
beats_version = get_stack_schemas()[str(min_stack_version)]["beats"]
# checks if new terms field(s) are in ecs, beats non-ecs or integration schemas
queryvalidator = QueryValidator(rule.contents.data.query)
@@ -113,43 +129,53 @@ class TestNewTerms(BaseRuleTest):
package=tag,
integration="",
rule_stack_version=min_stack_version,
packages_manifest=integration_manifests)
packages_manifest=integration_manifests,
)
if latest_tag_compat_ver:
integration_schema = integration_schemas[tag][latest_tag_compat_ver]
for policy_template in integration_schema.keys():
schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template])
for new_terms_field in rule.contents.data.new_terms.value:
assert new_terms_field in schema.keys(), \
f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
assert (
new_terms_field in schema.keys()
), f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_max_limit(self):
"""Test new terms max limit."""
# validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
feature_min_stack_extended_fields = Version.parse('8.6.0')
feature_min_stack = Version.parse("8.4.0")
feature_min_stack_extended_fields = Version.parse("8.6.0")
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version
min_stack_version = (
Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None
)
min_stack_version = (
current_package_version
if min_stack_version is None or min_stack_version < current_package_version
else min_stack_version
)
if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields:
assert len(rule.contents.data.new_terms.value) == 1, \
f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
assert (
len(rule.contents.data.new_terms.value) == 1
), f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.6.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_fields_unique(self):
"""Test new terms fields are unique."""
# validate fields are unique
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \
f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
assert len(set(rule.contents.data.new_terms.value)) == len(
rule.contents.data.new_terms.value
), f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
class TestESQLRules(BaseRuleTest):