From a5c100a65ba5bdcb3598b0a567e427bf20c9fe7c Mon Sep 17 00:00:00 2001 From: Eric Forte <119343520+eric-forte-elastic@users.noreply.github.com> Date: Thu, 9 Oct 2025 16:21:21 -0400 Subject: [PATCH] [Bug] Add unit tests and fix Alert Suppression schema validation for ThresholdQueryRuleData (#5196) * Add schema validation for AlertSuppressionMapping * Add support for indicator match alert suppression * Add unit tests * Update order and remove validates_schema method * Add comments * Add test for query rule duration only --- detection_rules/rule.py | 11 +- pyproject.toml | 2 +- tests/test_python_library.py | 226 ++++++++++++++++++++++++++++++++++- 3 files changed, 231 insertions(+), 8 deletions(-) diff --git a/detection_rules/rule.py b/detection_rules/rule.py index e80cd6848..f15856969 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -1074,10 +1074,17 @@ class ThreatMatchRuleData(QueryRuleData): # All of the possible rule types # Sort inverse of any inheritance - see comment in TOMLRuleContents.to_dict +# ThresholdQueryRuleData needs to be first in this union to handle cases where there is ambiguity between +# ThresholdAlertSuppression and AlertSuppressionMapping. Since AlertSuppressionMapping has duration as an +# optional field, ThresholdAlertSuppression objects can be mistakenly loaded as an AlertSuppressionMapping +# object with group_by and missing_fields_strategy as missing parameters, resulting in an error. +# Checking the type against ThresholdQueryRuleData first in the union prevent this from occurring. +# Please also keep issue 1141 in mind when handling union schemas. + AnyRuleData = ( - EQLRuleData + ThresholdQueryRuleData + | EQLRuleData | ESQLRuleData - | ThresholdQueryRuleData | ThreatMatchRuleData | MachineLearningRuleData | QueryRuleData diff --git a/pyproject.toml b/pyproject.toml index d8d1747fc..83e3f57ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "detection_rules" -version = "1.4.11" +version = "1.4.12" description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine." readme = "README.md" requires-python = ">=3.12" diff --git a/tests/test_python_library.py b/tests/test_python_library.py index b62a9e3b5..97c3ad05b 100644 --- a/tests/test_python_library.py +++ b/tests/test_python_library.py @@ -3,7 +3,10 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. +from typing import Any + import eql +from marshmallow import ValidationError from detection_rules.rule_loader import RuleCollection @@ -22,26 +25,46 @@ def mk_metadata(integrations: list[str], comments: str = "Test metadata") -> dic } -def mk_rule( +def mk_rule( # noqa: PLR0913 *, name: str, rule_id: str, description: str, risk_score: int, query: str, -) -> dict: + language: str = "eql", + query_type: str = "eql", + threshold: dict[str, Any] | None = None, + alert_suppression: dict[str, Any] | None = None, + index: list[str] | None = None, + threat_language: str | None = None, + threat_index: list[str] | None = None, + threat_indicator_path: str | None = None, + threat_mapping: list[Any] | None = None, +) -> dict[str, Any]: """Create rule dictionary.""" - return { + rule = { "author": ["Elastic"], "description": description, - "language": "eql", + "language": language, "name": name, "risk_score": risk_score, "rule_id": rule_id, "severity": "low", - "type": "eql", + "type": query_type, "query": query, + "alert_suppression": alert_suppression, } + if threshold is not None: + rule["threshold"] = threshold + if query_type == "threat_match": + rule["index"] = index + rule["threat_language"] = threat_language + rule["threat_index"] = threat_index + rule["threat_indicator_path"] = threat_indicator_path + rule["threat_mapping"] = threat_mapping + + return rule class TestEQLInSet(BaseRuleTest): @@ -283,3 +306,196 @@ class TestEQLSequencePerIntegration(BaseRuleTest): ), } rc.load_dict(rule) + + +class TestAlertSuppressionValidation(BaseRuleTest): + """Tests for alert_suppression field validation in rules.""" + + def test_threshold_rule_duration(self) -> None: + """Test that a threshold rule with alert_suppression with just duration validates correctly.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="threshold", + threshold={"field": [], "value": 200, "cardinality": []}, + alert_suppression={"duration": {"value": 5, "unit": "h"}}, + ), + } + _ = rc.load_dict(rule_dict) + + def test_query_rule_duration(self) -> None: + """Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="query", + threshold=None, + alert_suppression={"duration": {"value": 5, "unit": "h"}}, + ), + } + with self.assertRaises((ValidationError, TypeError)): + _ = rc.load_dict(rule_dict) + + def test_query_rule_group_by_missing_fields(self) -> None: + """Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="query", + threshold=None, + alert_suppression={"group_by": ["process.id"], "missing_fields_strategy": "suppress"}, + ), + } + _ = rc.load_dict(rule_dict) + + def test_query_rule_group_by(self) -> None: + """Test that a query rule with alert_suppression with just group_by is not valid.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="query", + threshold=None, + alert_suppression={"group_by": ["process.id"]}, + ), + } + with self.assertRaises((ValidationError, TypeError)): + _ = rc.load_dict(rule_dict) + + def test_query_rule_missing_fields_strategy(self) -> None: + """Test that a query rule with alert_suppression with just missing_fields_strategy is not valid.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="query", + threshold=None, + alert_suppression={"missing_fields_strategy": "suppress"}, + ), + } + with self.assertRaises((ValidationError, TypeError)): + _ = rc.load_dict(rule_dict) + + def test_threat_match_rule(self) -> None: + """Test that a threat_match rule with alert_suppression with all fields set is valid.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="threat_match", + threshold=None, + alert_suppression={ + "group_by": ["client.ip"], + "duration": {"value": 12, "unit": "h"}, + "missing_fields_strategy": "suppress", + }, + index=["logs-*"], + threat_language="kuery", + threat_index=["logs-*"], + threat_indicator_path="threat.indicator", + threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}], + ), + } + _ = rc.load_dict(rule_dict) + + def test_threat_match_rule_missing_fields_duration(self) -> None: + """Test that a threat_match rule with alert_suppression with missing_fields_strategy and duration is not valid.""" + rc = RuleCollection() + query = """ + process.name: \"test\" + """ + rule_dict: dict[str, Any] = { + "metadata": mk_metadata( + ["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup" + ), + "rule": mk_rule( + name="Fake Test Rule", + rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3", + description="Test Rule.", + risk_score=47, + query=query, + language="kuery", + query_type="threat_match", + threshold=None, + alert_suppression={ + "duration": {"value": 12, "unit": "h"}, + "missing_fields_strategy": "suppress", + }, + index=["logs-*"], + threat_language="kuery", + threat_index=["logs-*"], + threat_indicator_path="threat.indicator", + threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}], + ), + } + with self.assertRaises((ValidationError, TypeError)): + _ = rc.load_dict(rule_dict)