Add new required_fields as a build-time restricted field (#2059)

* Add new `require_field` restricted field
* validate new fields against BaseRuleData schema and global constant

Co-authored-by: Terrance DeJesus <terrance.dejesus@elastic.co>
Co-authored-by: brokensound77 <brokensound77@users.noreply.github.com>
This commit is contained in:
Mika Ayenson
2022-07-06 11:49:44 -04:00
committed by GitHub
parent 329530c8c3
commit c76a397969
2 changed files with 121 additions and 19 deletions
+115 -5
View File
@@ -18,7 +18,10 @@ import eql
from marshmallow import ValidationError, validates_schema
import kql
from . import beats
from . import ecs
from . import utils
from .misc import load_current_package_version
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
from .rule_formatter import toml_write, nested_normalize
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
@@ -26,6 +29,7 @@ from .schemas.stack_compat import get_restricted_fields
from .semver import Version
from .utils import cached
BUILD_FIELD_VERSIONS = {"required_fields": (Version('8.3'), None)}
_META_SCHEMA_REQ_DEFAULTS = {}
MIN_FLEET_PACKAGE_VERSION = '7.13.0'
@@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
@dataclass(frozen=True)
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
@dataclass
class RequiredFields:
name: definitions.NonEmptyStr
type: definitions.NonEmptyStr
ecs: bool
actions: Optional[list]
author: List[str]
building_block_type: Optional[str]
@@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
# output_index: Optional[str]
references: Optional[List[str]]
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
risk_score: definitions.RiskScore
risk_score_mapping: Optional[List[RiskScoreMapping]]
rule_id: definitions.UUIDString
@@ -220,9 +230,45 @@ class QueryValidator:
def ast(self) -> Any:
raise NotImplementedError
@property
def unique_fields(self) -> Any:
raise NotImplementedError
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
raise NotImplementedError()
@cached
def get_required_fields(self, index: str) -> List[dict]:
"""Retrieves fields needed for the query along with type information from the schema."""
current_version = Version(Version(load_current_package_version()) + (0,))
ecs_version = get_stack_schemas()[str(current_version)]['ecs']
beats_version = get_stack_schemas()[str(current_version)]['beats']
ecs_schema = ecs.get_schema(ecs_version)
beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version)
required = []
unique_fields = self.unique_fields or []
for fld in unique_fields:
field_type = ecs_schema.get(fld, {}).get('type')
is_ecs = field_type is not None
if beat_schema and not is_ecs:
field_type = beat_schema.get(fld, {}).get('type')
required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs))
return sorted(required, key=lambda f: f['name'])
@cached
def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict):
"""Get an assembled beats schema."""
beat_types = beats.parse_beats_from_index(index)
beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema)
return beat_types, beat_schema, schema
@dataclass(frozen=True)
class QueryRuleData(BaseRuleData):
@@ -251,6 +297,18 @@ class QueryRuleData(BaseRuleData):
if validator is not None:
return validator.ast
@cached_property
def unique_fields(self):
validator = self.validator
if validator is not None:
return validator.unique_fields
@cached
def get_required_fields(self, index: str) -> List[dict]:
validator = self.validator
if validator is not None:
return validator.get_required_fields(index or [])
@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
@@ -438,8 +496,7 @@ class BaseRuleContents(ABC):
return version + 1 if self.is_dirty else version
@staticmethod
def _post_dict_transform(obj: dict) -> dict:
def _post_dict_transform(self, obj: dict) -> dict:
"""Transform the converted API in place before sending to Kibana."""
# cleanup the whitespace in the rule
@@ -515,6 +572,59 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
def type(self) -> str:
return self.data.type
def _post_dict_transform(self, obj: dict) -> dict:
"""Transform the converted API in place before sending to Kibana."""
super()._post_dict_transform(obj)
self.add_related_integrations(obj)
self.add_required_fields(obj)
self.add_setup(obj)
# validate new fields against the schema
rule_type = obj['type']
subclass = self.get_data_subclass(rule_type)
subclass.from_dict(obj)
return obj
def add_related_integrations(self, obj: dict) -> None:
"""Add restricted field related_integrations to the obj."""
# field_name = "related_integrations"
...
def add_required_fields(self, obj: dict) -> None:
"""Add restricted field required_fields to the obj, derived from the query AST."""
if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene':
index = obj.get('index') or []
required_fields = self.data.get_required_fields(index)
else:
required_fields = []
field_name = "required_fields"
if self.check_restricted_field_version(field_name=field_name):
obj.setdefault(field_name, required_fields)
def add_setup(self, obj: dict) -> None:
"""Add restricted field setup to the obj."""
# field_name = "setup"
...
def check_explicit_restricted_field_version(self, field_name: str) -> bool:
"""Explicitly check restricted fields against global min and max versions."""
min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name]
return self.compare_field_versions(min_stack, max_stack)
def check_restricted_field_version(self, field_name: str) -> bool:
"""Check restricted fields against schema min and max versions."""
min_stack, max_stack = self.data.get_restricted_fields.get(field_name)
return self.compare_field_versions(min_stack, max_stack)
def compare_field_versions(self, min_stack: Version, max_stack: Version) -> bool:
"""Check current rule version is witihin min and max stack versions."""
current_version = Version(load_current_package_version())
max_stack = max_stack or current_version
return Version(min_stack) <= current_version >= Version(max_stack)
@validates_schema
def validate_query(self, value: dict, **kwargs):
"""Validate queries by calling into the validator for the relevant method."""
@@ -540,11 +650,11 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
def to_api_format(self, include_version=True) -> dict:
"""Convert the TOML rule to the API format."""
converted = self.data.to_dict()
converted = self._post_dict_transform(converted)
if include_version:
converted["version"] = self.autobumped_version
converted = self._post_dict_transform(converted)
return converted
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
+6 -14
View File
@@ -10,7 +10,7 @@ from typing import List, Optional, Union
import eql
import kql
from . import ecs, beats
from . import ecs
from .rule import QueryValidator, QueryRuleData, RuleMeta
@@ -21,7 +21,7 @@ class KQLValidator(QueryValidator):
def ast(self) -> kql.ast.Expression:
return kql.parse(self.query)
@property
@cached_property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field)))
@@ -29,9 +29,7 @@ class KQLValidator(QueryValidator):
return kql.to_eql(self.query)
def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
"""Static method to validate the query, called from the parent which contains [metadata] information."""
ast = self.ast
"""Validate the query, called from the parent which contains [metadata] information."""
if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
@@ -41,9 +39,7 @@ class KQLValidator(QueryValidator):
ecs_version = mapping['ecs']
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
beat_types = beats.parse_beats_from_index(data.index)
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
try:
kql.parse(self.query, schema=schema)
@@ -73,14 +69,12 @@ class EQLValidator(QueryValidator):
return [f for f in self.unique_fields if elasticsearch_type_family(eql_schema.kql_schema.get(f)) == 'text']
@property
@cached_property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field)))
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
"""Validate an EQL query while checking TOMLRule."""
ast = self.ast
if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
@@ -90,9 +84,7 @@ class EQLValidator(QueryValidator):
ecs_version = mapping['ecs']
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
beat_types = beats.parse_beats_from_index(data.index)
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
eql_schema = ecs.KqlSchema2Eql(schema)
try: