Add support for restricted fields (#2053)
* Add support for restricted fields (fields valid only in min/max stack versions) * add test to ensure rule backports wont exceed min compat
This commit is contained in:
@@ -6,16 +6,19 @@
|
||||
"""Generic mixin classes."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, Type, Optional, Any
|
||||
from typing import Any, Optional, TypeVar, Type
|
||||
|
||||
import json
|
||||
import marshmallow_dataclass
|
||||
import marshmallow_dataclass.union_field
|
||||
import marshmallow_jsonschema
|
||||
import marshmallow_union
|
||||
from marshmallow import Schema, ValidationError, fields
|
||||
from marshmallow import Schema, ValidationError, fields, validates_schema
|
||||
|
||||
from .misc import load_current_package_version
|
||||
from .schemas import definitions
|
||||
from .schemas.stack_compat import get_incompatible_fields
|
||||
from .semver import Version
|
||||
from .utils import cached, dict_hash
|
||||
|
||||
T = TypeVar('T')
|
||||
@@ -171,6 +174,26 @@ class LockDataclassMixin:
|
||||
path.write_text(json.dumps(contents, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
class StackCompatMixin:
|
||||
"""Mixin to restrict schema compatibility to defined stack versions."""
|
||||
|
||||
@validates_schema
|
||||
def validate_field_compatibility(self, data: dict, **kwargs):
|
||||
"""Verify stack-specific fields are properly applied to schema."""
|
||||
package_version = Version(load_current_package_version())
|
||||
schema_fields = getattr(self, 'fields', {})
|
||||
incompatible = get_incompatible_fields(list(schema_fields.values()), package_version)
|
||||
if not incompatible:
|
||||
return
|
||||
|
||||
package_version = load_current_package_version()
|
||||
for field, bounds in incompatible.items():
|
||||
min_compat, max_compat = bounds
|
||||
if data.get(field) is not None:
|
||||
raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, '
|
||||
f'min compatibility: {min_compat}, max compatibility: {max_compat}')
|
||||
|
||||
|
||||
class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema):
|
||||
|
||||
# Patch marshmallow-jsonschema to support marshmallow-dataclass[union]
|
||||
|
||||
+33
-4
@@ -11,7 +11,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union, Optional, List, Any, Dict
|
||||
from typing import Literal, Union, Optional, List, Any, Dict, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import eql
|
||||
@@ -19,9 +19,11 @@ from marshmallow import ValidationError, validates_schema
|
||||
|
||||
import kql
|
||||
from . import utils
|
||||
from .mixins import MarshmallowDataclassMixin
|
||||
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
|
||||
from .rule_formatter import toml_write, nested_normalize
|
||||
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
|
||||
from .schemas.stack_compat import get_restricted_fields
|
||||
from .semver import Version
|
||||
from .utils import cached
|
||||
|
||||
_META_SCHEMA_REQ_DEFAULTS = {}
|
||||
@@ -146,7 +148,7 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseRuleData(MarshmallowDataclassMixin):
|
||||
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
|
||||
actions: Optional[list]
|
||||
author: List[str]
|
||||
building_block_type: Optional[str]
|
||||
@@ -168,10 +170,13 @@ class BaseRuleData(MarshmallowDataclassMixin):
|
||||
# explicitly NOT allowed!
|
||||
# output_index: Optional[str]
|
||||
references: Optional[List[str]]
|
||||
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
risk_score: definitions.RiskScore
|
||||
risk_score_mapping: Optional[List[RiskScoreMapping]]
|
||||
rule_id: definitions.UUIDString
|
||||
rule_name_override: Optional[str]
|
||||
setup: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
severity_mapping: Optional[List[SeverityMapping]]
|
||||
severity: definitions.Severity
|
||||
tags: Optional[List[str]]
|
||||
@@ -186,7 +191,7 @@ class BaseRuleData(MarshmallowDataclassMixin):
|
||||
@classmethod
|
||||
def save_schema(cls):
|
||||
"""Save the schema as a jsonschema."""
|
||||
fields: List[dataclasses.Field] = dataclasses.fields(cls)
|
||||
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls)
|
||||
type_field = next(f for f in fields if f.name == "type")
|
||||
rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base"
|
||||
schema = cls.jsonschema()
|
||||
@@ -200,6 +205,12 @@ class BaseRuleData(MarshmallowDataclassMixin):
|
||||
def validate_query(self, meta: RuleMeta) -> None:
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def get_restricted_fields(self) -> Optional[Dict[str, tuple]]:
|
||||
"""Get stack version restricted fields."""
|
||||
fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self))
|
||||
return get_restricted_fields(fields)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryValidator:
|
||||
@@ -536,6 +547,24 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
|
||||
|
||||
return converted
|
||||
|
||||
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
|
||||
"""Check for compatibility between restricted fields and the min_stack_version of the rule."""
|
||||
default_min_stack = get_min_supported_stack_version(drop_patch=True)
|
||||
if self.metadata.min_stack_version is not None:
|
||||
min_stack = Version(self.metadata.min_stack_version)
|
||||
else:
|
||||
min_stack = default_min_stack
|
||||
restricted = self.data.get_restricted_fields
|
||||
|
||||
invalid = {}
|
||||
for _field, values in restricted.items():
|
||||
if self.data.get(_field) is not None:
|
||||
min_allowed, _ = values
|
||||
if min_stack < min_allowed:
|
||||
invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed}
|
||||
|
||||
return invalid
|
||||
|
||||
|
||||
@dataclass
|
||||
class TOMLRule:
|
||||
|
||||
@@ -268,6 +268,7 @@ def get_stack_versions(drop_patch=False) -> List[str]:
|
||||
return versions
|
||||
|
||||
|
||||
@cached
|
||||
def get_min_supported_stack_version(drop_patch=False) -> Version:
|
||||
"""Get the minimum defined and supported stack version."""
|
||||
stack_map = load_stack_schema_map()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
# or more contributor license agreements. Licensed under the Elastic License
|
||||
# 2.0; you may not use this file except in compliance with the Elastic License
|
||||
# 2.0.
|
||||
|
||||
from dataclasses import Field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from ..misc import cached
|
||||
from ..semver import Version
|
||||
|
||||
|
||||
@cached
|
||||
def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]:
|
||||
"""Get an optional min and max compatible versions of a field (from a schema or dataclass)."""
|
||||
# nested get is to support schema fields being passed directly from dataclass or fields in schema class, since
|
||||
# marshmallow_dataclass passes the embedded metadata directly
|
||||
min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat')
|
||||
max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat')
|
||||
min_compat = Version(min_compat) if min_compat else None
|
||||
max_compat = Version(max_compat) if max_compat else None
|
||||
return min_compat, max_compat
|
||||
|
||||
|
||||
@cached
|
||||
def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], Optional[Version]]]:
|
||||
"""Get a list of optional min and max compatible versions of fields (from a schema or dataclass)."""
|
||||
restricted = {}
|
||||
for _field in schema_fields:
|
||||
min_compat, max_compat = get_restricted_field(_field)
|
||||
if min_compat or max_compat:
|
||||
restricted[_field.name] = (min_compat, max_compat)
|
||||
|
||||
return restricted
|
||||
|
||||
|
||||
@cached
|
||||
def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> Optional[Dict[str, tuple]]:
|
||||
"""Get a list of fields that are incompatible with the package version."""
|
||||
if not schema_fields:
|
||||
return
|
||||
|
||||
incompatible = {}
|
||||
restricted_fields = get_restricted_fields(schema_fields)
|
||||
for field_name, values in restricted_fields.items():
|
||||
min_compat, max_compat = values
|
||||
|
||||
if min_compat and package_version < min_compat or max_compat and package_version > max_compat:
|
||||
incompatible[field_name] = (min_compat, max_compat)
|
||||
|
||||
return incompatible
|
||||
@@ -677,3 +677,22 @@ class TestIntegrationRules(BaseRuleTest):
|
||||
self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n'
|
||||
f'Expected: {note_str}\n\n'
|
||||
f'Actual: {rule.contents.data.note}')
|
||||
|
||||
|
||||
class TestIncompatibleFields(BaseRuleTest):
|
||||
"""Test stack restricted fields do not backport beyond allowable limits."""
|
||||
|
||||
def test_rule_backports_for_restricted_fields(self):
|
||||
"""Test that stack restricted fields will not backport to older rule versions."""
|
||||
invalid_rules = []
|
||||
|
||||
for rule in self.all_rules:
|
||||
invalid = rule.contents.check_restricted_fields_compatibility()
|
||||
if invalid:
|
||||
invalid_rules.append(f'{self.rule_str(rule)} {invalid}')
|
||||
|
||||
if invalid_rules:
|
||||
invalid_str = '\n'.join(invalid_rules)
|
||||
err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n'
|
||||
err_msg += invalid_str
|
||||
self.fail(err_msg)
|
||||
|
||||
Reference in New Issue
Block a user