diff --git a/detection_rules/rule.py b/detection_rules/rule.py index e27b004c8..a24945d12 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -229,7 +229,8 @@ class AlertSuppressionMapping(MarshmallowDataclassMixin, StackCompatMixin): value: int group_by: List[definitions.NonEmptyStr] - duration: Optional[AlertSuppressionDuration] = field(metadata=dict(metadata=dict(min_compat="8.7"))) + duration: Optional[AlertSuppressionDuration] + missing_fields_strategy: definitions.AlertSuppressionMissing @dataclass(frozen=True) @@ -247,7 +248,6 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): integration: Optional[definitions.NonEmptyStr] actions: Optional[list] - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.6"))) author: List[str] building_block_type: Optional[definitions.BuildingBlockType] description: str @@ -561,6 +561,7 @@ class QueryRuleData(BaseRuleData): index: Optional[List[str]] query: str language: definitions.FilterLanguages + alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.8"))) @cached_property def validator(self) -> Optional[QueryValidator]: @@ -592,6 +593,14 @@ class QueryRuleData(BaseRuleData): if validator is not None: return validator.get_required_fields(index or []) + @validates_schema + def validate_exceptions(self, data, **kwargs): + """Custom validation for query rule type and subclasses.""" + + # alert suppression is only valid for query rule type and not any of its subclasses + if data.get('alert_suppression') and data['type'] != 'query': + raise ValidationError("Alert suppression is only valid for query rule type.") + @dataclass(frozen=True) class MachineLearningRuleData(BaseRuleData): diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index 264b4b960..5f53e08fd 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -128,6 +128,8 @@ EXPECTED_RULE_TAGS = [ MACHINE_LEARNING_PACKAGES = ['LMD', 'DGA', 'DED', 'ProblemChild', 'Beaconing'] +AlertSuppressionMissing = NewType('AlertSuppressionMissing', str, + validate=validate.OneOf(['suppress', 'doNotSuppress'])) NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1)) TimeUnits = Literal['s', 'm', 'h'] BranchVer = NewType('BranchVer', str, validate=validate.Regexp(BRANCH_PATTERN)) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index ec0a2c79e..0c4e3c6ca 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -1276,22 +1276,22 @@ class TestNoteMarkdownPlugins(BaseRuleTest): class TestAlertSuppression(BaseRuleTest): """Test rule alert suppression.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), "Test only applicable to 8.6+ stacks for rule alert suppression feature.") def test_group_length(self): """Test to ensure the rule alert suppression group_by does not exceed 3 elements.""" for rule in self.production_rules: - if rule.contents.data.alert_suppression: + if rule.contents.data.get('alert_suppression'): group_length = len(rule.contents.data.alert_suppression.group_by) if group_length > 3: self.fail(f'{self.rule_str(rule)} has rule alert suppression with more than 3 elements.') - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.6+ stacks for rule alert suppression feature.") + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), + "Test only applicable to 8.8+ stacks for rule alert suppression feature.") def test_group_field_in_schemas(self): """Test to ensure the fields are defined is in ECS/Beats/Integrations schema.""" for rule in self.production_rules: - if rule.contents.data.alert_suppression: + if rule.contents.data.get('alert_suppression'): group_by_fields = rule.contents.data.alert_suppression.group_by min_stack_version = rule.contents.metadata.get("min_stack_version") if min_stack_version is None: @@ -1316,33 +1316,3 @@ class TestAlertSuppression(BaseRuleTest): if fld not in schema.keys(): self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ found in ECS, Beats, or non-ecs schemas") - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.6+ stacks for rule alert suppression feature.") - def test_stack_version(self): - """Test to ensure the stack version is 8.6+""" - for rule in self.production_rules: - if rule.contents.data.alert_suppression: - per_time = rule.contents.data.alert_suppression.get("duration", None) - min_stack_version = rule.contents.metadata.get("min_stack_version") - if min_stack_version is None: - min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - else: - min_stack_version = Version.parse(min_stack_version) - if not per_time and min_stack_version < Version.parse("8.6.0"): - self.fail(f'{self.rule_str(rule)} has rule alert suppression but \ - min_stack is not 8.6+') - elif per_time and min_stack_version < Version.parse("8.7.0"): - self.fail(f'{self.rule_str(rule)} has rule alert suppression with \ - per time but min_stack is not 8.7+') - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.6+ stacks for rule alert suppression feature.") - def test_query_type(self): - """Test to ensure the query type is KQL only.""" - for rule in self.production_rules: - if rule.contents.data.alert_suppression: - rule_type = rule.contents.data.language - if rule_type != 'kuery': - self.fail(f'{self.rule_str(rule)} has rule alert suppression with \ - but query language is not KQL')