From bde05d63c65999d83c41bb2b3e00f9726d475536 Mon Sep 17 00:00:00 2001 From: Mika Ayenson Date: Mon, 12 Feb 2024 09:55:46 -0600 Subject: [PATCH] [FR] Add support for Threshold Alert Suppression (#3433) (cherry picked from commit c3ca01ebcc40ed2806d236177e6657238d2c18a1) --- detection_rules/rule.py | 26 ++++++++++++++++++-------- detection_rules/schemas/definitions.py | 6 +++--- tests/test_all_rules.py | 22 ++++++++-------------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 456363a79..9a8916a0a 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -215,20 +215,29 @@ class FlatThreatMapping(MarshmallowDataclassMixin): sub_technique_ids: List[str] +@dataclass(frozen=True) +class AlertSuppressionDuration: + """Mapping to alert suppression duration.""" + unit: definitions.TimeUnits + value: definitions.AlertSuppressionValue + + @dataclass(frozen=True) class AlertSuppressionMapping(MarshmallowDataclassMixin, StackCompatMixin): """Mapping to alert suppression.""" - @dataclass - class AlertSuppressionDuration: - """Mapping to allert suppression duration.""" - unit: definitions.TimeUnits - value: int - group_by: List[definitions.NonEmptyStr] + group_by: definitions.AlertSuppressionGroupBy duration: Optional[AlertSuppressionDuration] missing_fields_strategy: definitions.AlertSuppressionMissing +@dataclass(frozen=True) +class ThresholdAlertSuppression: + """Mapping to alert suppression.""" + + duration: AlertSuppressionDuration + + @dataclass(frozen=True) class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): @dataclass @@ -612,8 +621,8 @@ class QueryRuleData(BaseRuleData): def validates_query_data(self, data, **kwargs): """Custom validation for query rule type and subclasses.""" # alert suppression is only valid for query rule type and not any of its subclasses - if data.get('alert_suppression') and data['type'] != 'query': - raise ValidationError("Alert suppression is only valid for query rule type.") + if data.get('alert_suppression') and data['type'] not in ('query', 'threshold'): + raise ValidationError("Alert suppression is only valid for query and threshold rule types.") @dataclass(frozen=True) @@ -641,6 +650,7 @@ class ThresholdQueryRuleData(QueryRuleData): type: Literal["threshold"] threshold: ThresholdMapping + alert_suppression: Optional[ThresholdAlertSuppression] = field(metadata=dict(metadata=dict(min_compat="8.12"))) @dataclass(frozen=True) diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index 38e4d2822..38aa943fa 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -126,12 +126,12 @@ EXPECTED_RULE_TAGS = [ 'Use Case: UEBA', 'Use Case: Vulnerability' ] - +NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1)) MACHINE_LEARNING_PACKAGES = ['LMD', 'DGA', 'DED', 'ProblemChild', 'Beaconing'] - +AlertSuppressionGroupBy = NewType('AlertSuppressionGroupBy', List[NonEmptyStr], validate=validate.Length(min=1, max=3)) AlertSuppressionMissing = NewType('AlertSuppressionMissing', str, validate=validate.OneOf(['suppress', 'doNotSuppress'])) -NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1)) +AlertSuppressionValue = NewType("AlertSupressionValue", int, validate=validate.Range(min=1)) TimeUnits = Literal['s', 'm', 'h'] BranchVer = NewType('BranchVer', str, validate=validate.Regexp(BRANCH_PATTERN)) CardinalityFields = NewType('CardinalityFields', List[NonEmptyStr], validate=validate.Length(min=0, max=3)) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index 6b6dfda35..e30c8b7de 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -24,8 +24,8 @@ from detection_rules.integrations import (find_latest_compatible_version, load_integrations_schemas) from detection_rules.misc import load_current_package_version from detection_rules.packaging import current_stack_version -from detection_rules.rule import (QueryRuleData, QueryValidator, - TOMLRuleContents) +from detection_rules.rule import (AlertSuppressionMapping, QueryRuleData, QueryValidator, + ThresholdAlertSuppression, TOMLRuleContents) from detection_rules.rule_loader import FILE_PATTERN from detection_rules.rule_validators import EQLValidator, KQLValidator from detection_rules.schemas import definitions, get_stack_schemas @@ -1215,23 +1215,17 @@ class TestNoteMarkdownPlugins(BaseRuleTest): class TestAlertSuppression(BaseRuleTest): """Test rule alert suppression.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), - "Test only applicable to 8.6+ stacks for rule alert suppression feature.") - def test_group_length(self): - """Test to ensure the rule alert suppression group_by does not exceed 3 elements.""" - for rule in self.production_rules: - if rule.contents.data.get('alert_suppression'): - group_length = len(rule.contents.data.alert_suppression.group_by) - if group_length > 3: - self.fail(f'{self.rule_str(rule)} has rule alert suppression with more than 3 elements.') - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), "Test only applicable to 8.8+ stacks for rule alert suppression feature.") def test_group_field_in_schemas(self): """Test to ensure the fields are defined is in ECS/Beats/Integrations schema.""" for rule in self.production_rules: - if rule.contents.data.get('alert_suppression'): - group_by_fields = rule.contents.data.alert_suppression.group_by + rule_type = rule.contents.data.get('type') + if rule_type in ('query', 'threshold') and rule.contents.data.get('alert_suppression'): + if isinstance(rule.contents.data.alert_suppression, AlertSuppressionMapping): + group_by_fields = rule.contents.data.alert_suppression.group_by + elif isinstance(rule.contents.data.alert_suppression, ThresholdAlertSuppression): + group_by_fields = rule.contents.data.threshold.field min_stack_version = rule.contents.metadata.get("min_stack_version") if min_stack_version is None: min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)