Add support for restricted fields (#2053)

* Add support for restricted fields (fields valid only in min/max stack versions)
* add test to ensure rule backports wont exceed min compat
This commit is contained in:
Justin Ibarra
2022-06-27 10:02:15 -05:00
committed by GitHub
parent 4ef1a1a627
commit cc01d3fb1a
5 changed files with 129 additions and 6 deletions
+25 -2
View File
@@ -6,16 +6,19 @@
"""Generic mixin classes."""
from pathlib import Path
from typing import TypeVar, Type, Optional, Any
from typing import Any, Optional, TypeVar, Type
import json
import marshmallow_dataclass
import marshmallow_dataclass.union_field
import marshmallow_jsonschema
import marshmallow_union
from marshmallow import Schema, ValidationError, fields
from marshmallow import Schema, ValidationError, fields, validates_schema
from .misc import load_current_package_version
from .schemas import definitions
from .schemas.stack_compat import get_incompatible_fields
from .semver import Version
from .utils import cached, dict_hash
T = TypeVar('T')
@@ -171,6 +174,26 @@ class LockDataclassMixin:
path.write_text(json.dumps(contents, indent=2, sort_keys=True))
class StackCompatMixin:
"""Mixin to restrict schema compatibility to defined stack versions."""
@validates_schema
def validate_field_compatibility(self, data: dict, **kwargs):
"""Verify stack-specific fields are properly applied to schema."""
package_version = Version(load_current_package_version())
schema_fields = getattr(self, 'fields', {})
incompatible = get_incompatible_fields(list(schema_fields.values()), package_version)
if not incompatible:
return
package_version = load_current_package_version()
for field, bounds in incompatible.items():
min_compat, max_compat = bounds
if data.get(field) is not None:
raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, '
f'min compatibility: {min_compat}, max compatibility: {max_compat}')
class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema):
# Patch marshmallow-jsonschema to support marshmallow-dataclass[union]
+33 -4
View File
@@ -11,7 +11,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Literal, Union, Optional, List, Any, Dict
from typing import Literal, Union, Optional, List, Any, Dict, Tuple
from uuid import uuid4
import eql
@@ -19,9 +19,11 @@ from marshmallow import ValidationError, validates_schema
import kql
from . import utils
from .mixins import MarshmallowDataclassMixin
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
from .rule_formatter import toml_write, nested_normalize
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
from .schemas.stack_compat import get_restricted_fields
from .semver import Version
from .utils import cached
_META_SCHEMA_REQ_DEFAULTS = {}
@@ -146,7 +148,7 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
@dataclass(frozen=True)
class BaseRuleData(MarshmallowDataclassMixin):
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
actions: Optional[list]
author: List[str]
building_block_type: Optional[str]
@@ -168,10 +170,13 @@ class BaseRuleData(MarshmallowDataclassMixin):
# explicitly NOT allowed!
# output_index: Optional[str]
references: Optional[List[str]]
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
risk_score: definitions.RiskScore
risk_score_mapping: Optional[List[RiskScoreMapping]]
rule_id: definitions.UUIDString
rule_name_override: Optional[str]
setup: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.3")))
severity_mapping: Optional[List[SeverityMapping]]
severity: definitions.Severity
tags: Optional[List[str]]
@@ -186,7 +191,7 @@ class BaseRuleData(MarshmallowDataclassMixin):
@classmethod
def save_schema(cls):
"""Save the schema as a jsonschema."""
fields: List[dataclasses.Field] = dataclasses.fields(cls)
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls)
type_field = next(f for f in fields if f.name == "type")
rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base"
schema = cls.jsonschema()
@@ -200,6 +205,12 @@ class BaseRuleData(MarshmallowDataclassMixin):
def validate_query(self, meta: RuleMeta) -> None:
pass
@cached_property
def get_restricted_fields(self) -> Optional[Dict[str, tuple]]:
"""Get stack version restricted fields."""
fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self))
return get_restricted_fields(fields)
@dataclass
class QueryValidator:
@@ -536,6 +547,24 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
return converted
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
"""Check for compatibility between restricted fields and the min_stack_version of the rule."""
default_min_stack = get_min_supported_stack_version(drop_patch=True)
if self.metadata.min_stack_version is not None:
min_stack = Version(self.metadata.min_stack_version)
else:
min_stack = default_min_stack
restricted = self.data.get_restricted_fields
invalid = {}
for _field, values in restricted.items():
if self.data.get(_field) is not None:
min_allowed, _ = values
if min_stack < min_allowed:
invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed}
return invalid
@dataclass
class TOMLRule:
+1
View File
@@ -268,6 +268,7 @@ def get_stack_versions(drop_patch=False) -> List[str]:
return versions
@cached
def get_min_supported_stack_version(drop_patch=False) -> Version:
"""Get the minimum defined and supported stack version."""
stack_map = load_stack_schema_map()
+51
View File
@@ -0,0 +1,51 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from dataclasses import Field
from typing import Dict, List, Optional, Tuple
from ..misc import cached
from ..semver import Version
@cached
def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]:
"""Get an optional min and max compatible versions of a field (from a schema or dataclass)."""
# nested get is to support schema fields being passed directly from dataclass or fields in schema class, since
# marshmallow_dataclass passes the embedded metadata directly
min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat')
max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat')
min_compat = Version(min_compat) if min_compat else None
max_compat = Version(max_compat) if max_compat else None
return min_compat, max_compat
@cached
def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], Optional[Version]]]:
"""Get a list of optional min and max compatible versions of fields (from a schema or dataclass)."""
restricted = {}
for _field in schema_fields:
min_compat, max_compat = get_restricted_field(_field)
if min_compat or max_compat:
restricted[_field.name] = (min_compat, max_compat)
return restricted
@cached
def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> Optional[Dict[str, tuple]]:
"""Get a list of fields that are incompatible with the package version."""
if not schema_fields:
return
incompatible = {}
restricted_fields = get_restricted_fields(schema_fields)
for field_name, values in restricted_fields.items():
min_compat, max_compat = values
if min_compat and package_version < min_compat or max_compat and package_version > max_compat:
incompatible[field_name] = (min_compat, max_compat)
return incompatible
+19
View File
@@ -677,3 +677,22 @@ class TestIntegrationRules(BaseRuleTest):
self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n'
f'Expected: {note_str}\n\n'
f'Actual: {rule.contents.data.note}')
class TestIncompatibleFields(BaseRuleTest):
"""Test stack restricted fields do not backport beyond allowable limits."""
def test_rule_backports_for_restricted_fields(self):
"""Test that stack restricted fields will not backport to older rule versions."""
invalid_rules = []
for rule in self.all_rules:
invalid = rule.contents.check_restricted_fields_compatibility()
if invalid:
invalid_rules.append(f'{self.rule_str(rule)} {invalid}')
if invalid_rules:
invalid_str = '\n'.join(invalid_rules)
err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n'
err_msg += invalid_str
self.fail(err_msg)