[Bug] Add unit tests and fix Alert Suppression schema validation for ThresholdQueryRuleData (#5196)

* Add schema validation for AlertSuppressionMapping

* Add support for indicator match alert suppression

* Add unit tests

* Update order and remove validates_schema method

* Add comments

* Add test for query rule duration only
This commit is contained in:
Eric Forte
2025-10-09 16:21:21 -04:00
committed by GitHub
parent ebb7bb5bce
commit a5c100a65b
3 changed files with 231 additions and 8 deletions
+9 -2
View File
@@ -1074,10 +1074,17 @@ class ThreatMatchRuleData(QueryRuleData):
# All of the possible rule types
# Sort inverse of any inheritance - see comment in TOMLRuleContents.to_dict
# ThresholdQueryRuleData needs to be first in this union to handle cases where there is ambiguity between
# ThresholdAlertSuppression and AlertSuppressionMapping. Since AlertSuppressionMapping has duration as an
# optional field, ThresholdAlertSuppression objects can be mistakenly loaded as an AlertSuppressionMapping
# object with group_by and missing_fields_strategy as missing parameters, resulting in an error.
# Checking the type against ThresholdQueryRuleData first in the union prevent this from occurring.
# Please also keep issue 1141 in mind when handling union schemas.
AnyRuleData = (
EQLRuleData
ThresholdQueryRuleData
| EQLRuleData
| ESQLRuleData
| ThresholdQueryRuleData
| ThreatMatchRuleData
| MachineLearningRuleData
| QueryRuleData
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "detection_rules"
version = "1.4.11"
version = "1.4.12"
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Securitys Detection Engine."
readme = "README.md"
requires-python = ">=3.12"
+221 -5
View File
@@ -3,7 +3,10 @@
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from typing import Any
import eql
from marshmallow import ValidationError
from detection_rules.rule_loader import RuleCollection
@@ -22,26 +25,46 @@ def mk_metadata(integrations: list[str], comments: str = "Test metadata") -> dic
}
def mk_rule(
def mk_rule( # noqa: PLR0913
*,
name: str,
rule_id: str,
description: str,
risk_score: int,
query: str,
) -> dict:
language: str = "eql",
query_type: str = "eql",
threshold: dict[str, Any] | None = None,
alert_suppression: dict[str, Any] | None = None,
index: list[str] | None = None,
threat_language: str | None = None,
threat_index: list[str] | None = None,
threat_indicator_path: str | None = None,
threat_mapping: list[Any] | None = None,
) -> dict[str, Any]:
"""Create rule dictionary."""
return {
rule = {
"author": ["Elastic"],
"description": description,
"language": "eql",
"language": language,
"name": name,
"risk_score": risk_score,
"rule_id": rule_id,
"severity": "low",
"type": "eql",
"type": query_type,
"query": query,
"alert_suppression": alert_suppression,
}
if threshold is not None:
rule["threshold"] = threshold
if query_type == "threat_match":
rule["index"] = index
rule["threat_language"] = threat_language
rule["threat_index"] = threat_index
rule["threat_indicator_path"] = threat_indicator_path
rule["threat_mapping"] = threat_mapping
return rule
class TestEQLInSet(BaseRuleTest):
@@ -283,3 +306,196 @@ class TestEQLSequencePerIntegration(BaseRuleTest):
),
}
rc.load_dict(rule)
class TestAlertSuppressionValidation(BaseRuleTest):
"""Tests for alert_suppression field validation in rules."""
def test_threshold_rule_duration(self) -> None:
"""Test that a threshold rule with alert_suppression with just duration validates correctly."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="threshold",
threshold={"field": [], "value": 200, "cardinality": []},
alert_suppression={"duration": {"value": 5, "unit": "h"}},
),
}
_ = rc.load_dict(rule_dict)
def test_query_rule_duration(self) -> None:
"""Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="query",
threshold=None,
alert_suppression={"duration": {"value": 5, "unit": "h"}},
),
}
with self.assertRaises((ValidationError, TypeError)):
_ = rc.load_dict(rule_dict)
def test_query_rule_group_by_missing_fields(self) -> None:
"""Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="query",
threshold=None,
alert_suppression={"group_by": ["process.id"], "missing_fields_strategy": "suppress"},
),
}
_ = rc.load_dict(rule_dict)
def test_query_rule_group_by(self) -> None:
"""Test that a query rule with alert_suppression with just group_by is not valid."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="query",
threshold=None,
alert_suppression={"group_by": ["process.id"]},
),
}
with self.assertRaises((ValidationError, TypeError)):
_ = rc.load_dict(rule_dict)
def test_query_rule_missing_fields_strategy(self) -> None:
"""Test that a query rule with alert_suppression with just missing_fields_strategy is not valid."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="query",
threshold=None,
alert_suppression={"missing_fields_strategy": "suppress"},
),
}
with self.assertRaises((ValidationError, TypeError)):
_ = rc.load_dict(rule_dict)
def test_threat_match_rule(self) -> None:
"""Test that a threat_match rule with alert_suppression with all fields set is valid."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="threat_match",
threshold=None,
alert_suppression={
"group_by": ["client.ip"],
"duration": {"value": 12, "unit": "h"},
"missing_fields_strategy": "suppress",
},
index=["logs-*"],
threat_language="kuery",
threat_index=["logs-*"],
threat_indicator_path="threat.indicator",
threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}],
),
}
_ = rc.load_dict(rule_dict)
def test_threat_match_rule_missing_fields_duration(self) -> None:
"""Test that a threat_match rule with alert_suppression with missing_fields_strategy and duration is not valid."""
rc = RuleCollection()
query = """
process.name: \"test\"
"""
rule_dict: dict[str, Any] = {
"metadata": mk_metadata(
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
),
"rule": mk_rule(
name="Fake Test Rule",
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
description="Test Rule.",
risk_score=47,
query=query,
language="kuery",
query_type="threat_match",
threshold=None,
alert_suppression={
"duration": {"value": 12, "unit": "h"},
"missing_fields_strategy": "suppress",
},
index=["logs-*"],
threat_language="kuery",
threat_index=["logs-*"],
threat_indicator_path="threat.indicator",
threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}],
),
}
with self.assertRaises((ValidationError, TypeError)):
_ = rc.load_dict(rule_dict)