diff --git a/detection_rules/mixins.py b/detection_rules/mixins.py index d00820db6..cd448ec3f 100644 --- a/detection_rules/mixins.py +++ b/detection_rules/mixins.py @@ -6,16 +6,19 @@ """Generic mixin classes.""" from pathlib import Path -from typing import TypeVar, Type, Optional, Any +from typing import Any, Optional, TypeVar, Type import json import marshmallow_dataclass import marshmallow_dataclass.union_field import marshmallow_jsonschema import marshmallow_union -from marshmallow import Schema, ValidationError, fields +from marshmallow import Schema, ValidationError, fields, validates_schema +from .misc import load_current_package_version from .schemas import definitions +from .schemas.stack_compat import get_incompatible_fields +from .semver import Version from .utils import cached, dict_hash T = TypeVar('T') @@ -171,6 +174,26 @@ class LockDataclassMixin: path.write_text(json.dumps(contents, indent=2, sort_keys=True)) +class StackCompatMixin: + """Mixin to restrict schema compatibility to defined stack versions.""" + + @validates_schema + def validate_field_compatibility(self, data: dict, **kwargs): + """Verify stack-specific fields are properly applied to schema.""" + package_version = Version(load_current_package_version()) + schema_fields = getattr(self, 'fields', {}) + incompatible = get_incompatible_fields(list(schema_fields.values()), package_version) + if not incompatible: + return + + package_version = load_current_package_version() + for field, bounds in incompatible.items(): + min_compat, max_compat = bounds + if data.get(field) is not None: + raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, ' + f'min compatibility: {min_compat}, max compatibility: {max_compat}') + + class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema): # Patch marshmallow-jsonschema to support marshmallow-dataclass[union] diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 471f73d10..f24668686 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Literal, Union, Optional, List, Any, Dict +from typing import Literal, Union, Optional, List, Any, Dict, Tuple from uuid import uuid4 import eql @@ -19,9 +19,11 @@ from marshmallow import ValidationError, validates_schema import kql from . import utils -from .mixins import MarshmallowDataclassMixin +from .mixins import MarshmallowDataclassMixin, StackCompatMixin from .rule_formatter import toml_write, nested_normalize from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version +from .schemas.stack_compat import get_restricted_fields +from .semver import Version from .utils import cached _META_SCHEMA_REQ_DEFAULTS = {} @@ -146,7 +148,7 @@ class FlatThreatMapping(MarshmallowDataclassMixin): @dataclass(frozen=True) -class BaseRuleData(MarshmallowDataclassMixin): +class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): actions: Optional[list] author: List[str] building_block_type: Optional[str] @@ -168,10 +170,13 @@ class BaseRuleData(MarshmallowDataclassMixin): # explicitly NOT allowed! # output_index: Optional[str] references: Optional[List[str]] + related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) + required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) risk_score: definitions.RiskScore risk_score_mapping: Optional[List[RiskScoreMapping]] rule_id: definitions.UUIDString rule_name_override: Optional[str] + setup: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.3"))) severity_mapping: Optional[List[SeverityMapping]] severity: definitions.Severity tags: Optional[List[str]] @@ -186,7 +191,7 @@ class BaseRuleData(MarshmallowDataclassMixin): @classmethod def save_schema(cls): """Save the schema as a jsonschema.""" - fields: List[dataclasses.Field] = dataclasses.fields(cls) + fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls) type_field = next(f for f in fields if f.name == "type") rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base" schema = cls.jsonschema() @@ -200,6 +205,12 @@ class BaseRuleData(MarshmallowDataclassMixin): def validate_query(self, meta: RuleMeta) -> None: pass + @cached_property + def get_restricted_fields(self) -> Optional[Dict[str, tuple]]: + """Get stack version restricted fields.""" + fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self)) + return get_restricted_fields(fields) + @dataclass class QueryValidator: @@ -536,6 +547,24 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin): return converted + def check_restricted_fields_compatibility(self) -> Dict[str, dict]: + """Check for compatibility between restricted fields and the min_stack_version of the rule.""" + default_min_stack = get_min_supported_stack_version(drop_patch=True) + if self.metadata.min_stack_version is not None: + min_stack = Version(self.metadata.min_stack_version) + else: + min_stack = default_min_stack + restricted = self.data.get_restricted_fields + + invalid = {} + for _field, values in restricted.items(): + if self.data.get(_field) is not None: + min_allowed, _ = values + if min_stack < min_allowed: + invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed} + + return invalid + @dataclass class TOMLRule: diff --git a/detection_rules/schemas/__init__.py b/detection_rules/schemas/__init__.py index 51293ec7d..377237a18 100644 --- a/detection_rules/schemas/__init__.py +++ b/detection_rules/schemas/__init__.py @@ -268,6 +268,7 @@ def get_stack_versions(drop_patch=False) -> List[str]: return versions +@cached def get_min_supported_stack_version(drop_patch=False) -> Version: """Get the minimum defined and supported stack version.""" stack_map = load_stack_schema_map() diff --git a/detection_rules/schemas/stack_compat.py b/detection_rules/schemas/stack_compat.py new file mode 100644 index 000000000..a2c274e2e --- /dev/null +++ b/detection_rules/schemas/stack_compat.py @@ -0,0 +1,51 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +from dataclasses import Field +from typing import Dict, List, Optional, Tuple + +from ..misc import cached +from ..semver import Version + + +@cached +def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]: + """Get an optional min and max compatible versions of a field (from a schema or dataclass).""" + # nested get is to support schema fields being passed directly from dataclass or fields in schema class, since + # marshmallow_dataclass passes the embedded metadata directly + min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat') + max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat') + min_compat = Version(min_compat) if min_compat else None + max_compat = Version(max_compat) if max_compat else None + return min_compat, max_compat + + +@cached +def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], Optional[Version]]]: + """Get a list of optional min and max compatible versions of fields (from a schema or dataclass).""" + restricted = {} + for _field in schema_fields: + min_compat, max_compat = get_restricted_field(_field) + if min_compat or max_compat: + restricted[_field.name] = (min_compat, max_compat) + + return restricted + + +@cached +def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> Optional[Dict[str, tuple]]: + """Get a list of fields that are incompatible with the package version.""" + if not schema_fields: + return + + incompatible = {} + restricted_fields = get_restricted_fields(schema_fields) + for field_name, values in restricted_fields.items(): + min_compat, max_compat = values + + if min_compat and package_version < min_compat or max_compat and package_version > max_compat: + incompatible[field_name] = (min_compat, max_compat) + + return incompatible diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index 2170608aa..3e781dce8 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -677,3 +677,22 @@ class TestIntegrationRules(BaseRuleTest): self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n' f'Expected: {note_str}\n\n' f'Actual: {rule.contents.data.note}') + + +class TestIncompatibleFields(BaseRuleTest): + """Test stack restricted fields do not backport beyond allowable limits.""" + + def test_rule_backports_for_restricted_fields(self): + """Test that stack restricted fields will not backport to older rule versions.""" + invalid_rules = [] + + for rule in self.all_rules: + invalid = rule.contents.check_restricted_fields_compatibility() + if invalid: + invalid_rules.append(f'{self.rule_str(rule)} {invalid}') + + if invalid_rules: + invalid_str = '\n'.join(invalid_rules) + err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n' + err_msg += invalid_str + self.fail(err_msg)