Add new required_fields as a build-time restricted field (#2059)
* Add new `required_fields` restricted field
* Validate new fields against the BaseRuleData schema and the global constant

Co-authored-by: Terrance DeJesus <terrance.dejesus@elastic.co>
Co-authored-by: brokensound77 <brokensound77@users.noreply.github.com>
This commit is contained in:
+115
-5
@@ -18,7 +18,10 @@ import eql
|
||||
from marshmallow import ValidationError, validates_schema
|
||||
|
||||
import kql
|
||||
from . import beats
|
||||
from . import ecs
|
||||
from . import utils
|
||||
from .misc import load_current_package_version
|
||||
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
|
||||
from .rule_formatter import toml_write, nested_normalize
|
||||
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
|
||||
@@ -26,6 +29,7 @@ from .schemas.stack_compat import get_restricted_fields
|
||||
from .semver import Version
|
||||
from .utils import cached
|
||||
|
||||
BUILD_FIELD_VERSIONS = {"required_fields": (Version('8.3'), None)}
|
||||
_META_SCHEMA_REQ_DEFAULTS = {}
|
||||
MIN_FLEET_PACKAGE_VERSION = '7.13.0'
|
||||
|
||||
@@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
|
||||
@dataclass
class RequiredFields:
    """Schema for one entry of the build-time ``required_fields`` list."""

    # name of a field referenced by the rule query
    name: definitions.NonEmptyStr
    # type of the field as found in the schema
    type: definitions.NonEmptyStr
    # whether the field is defined in the ECS schema
    ecs: bool
|
||||
|
||||
actions: Optional[list]
|
||||
author: List[str]
|
||||
building_block_type: Optional[str]
|
||||
@@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
|
||||
# output_index: Optional[str]
|
||||
references: Optional[List[str]]
|
||||
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
|
||||
risk_score: definitions.RiskScore
|
||||
risk_score_mapping: Optional[List[RiskScoreMapping]]
|
||||
rule_id: definitions.UUIDString
|
||||
@@ -220,9 +230,45 @@ class QueryValidator:
|
||||
def ast(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
def unique_fields(self) -> Any:
    """Fields referenced by the query; concrete validators must provide this."""
    raise NotImplementedError()
|
||||
|
||||
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
    """Validate the query against the rule data and metadata; subclasses must implement this."""
    raise NotImplementedError
|
||||
|
||||
@cached
def get_required_fields(self, index: List[str]) -> List[dict]:
    """Retrieve the fields needed for the query along with type information from the schema.

    :param index: index patterns of the rule; used to decide which beats schemas apply.
    :return: one ``dict(name=..., type=..., ecs=...)`` per unique query field, sorted by name.
        ``type`` is ``'unknown'`` when neither the ECS nor the beats schema defines the field.
    """
    # the package version gets a trailing 0 component before being used as the stack schema key
    current_version = Version(Version(load_current_package_version()) + (0,))
    stack_schema = get_stack_schemas()[str(current_version)]
    ecs_version = stack_schema['ecs']
    beats_version = stack_schema['beats']
    ecs_schema = ecs.get_schema(ecs_version)

    # only the beats schema is needed here; the beat types and merged KQL schema are ignored
    _, beat_schema, _ = self.get_beats_schema(index or [], beats_version, ecs_version)

    required = []
    for fld in self.unique_fields or []:
        field_type = ecs_schema.get(fld, {}).get('type')
        is_ecs = field_type is not None

        # fall back to the beats schema for fields that ECS does not define
        if beat_schema and not is_ecs:
            field_type = beat_schema.get(fld, {}).get('type')

        required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs))

    return sorted(required, key=lambda f: f['name'])
|
||||
|
||||
@cached
def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict):
    """Build the beats schema and the combined KQL schema for the given index patterns.

    Returns a ``(beat_types, beat_schema, schema)`` tuple; ``beat_schema`` is ``None`` when no
    beat types are parsed from ``index``.
    """
    parsed_beats = beats.parse_beats_from_index(index)

    if parsed_beats:
        assembled_beats = beats.get_schema_from_kql(self.ast, parsed_beats, version=beats_version)
    else:
        assembled_beats = None

    combined = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=assembled_beats)
    return parsed_beats, assembled_beats, combined
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QueryRuleData(BaseRuleData):
|
||||
@@ -251,6 +297,18 @@ class QueryRuleData(BaseRuleData):
|
||||
if validator is not None:
|
||||
return validator.ast
|
||||
|
||||
@cached_property
def unique_fields(self):
    """Unique fields reported by the rule's query validator, or None when there is no validator."""
    query_validator = self.validator
    if query_validator is None:
        return None
    return query_validator.unique_fields
|
||||
|
||||
@cached
def get_required_fields(self, index: List[str]) -> Optional[List[dict]]:
    """Delegate required-field discovery to the rule's query validator.

    :param index: index patterns of the rule; a falsy value is passed on as an empty list.
    :return: the validator's required fields, or ``None`` when the rule has no validator.
    """
    validator = self.validator
    if validator is None:
        return None
    return validator.get_required_fields(index or [])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MachineLearningRuleData(BaseRuleData):
|
||||
@@ -438,8 +496,7 @@ class BaseRuleContents(ABC):
|
||||
|
||||
return version + 1 if self.is_dirty else version
|
||||
|
||||
@staticmethod
|
||||
def _post_dict_transform(obj: dict) -> dict:
|
||||
def _post_dict_transform(self, obj: dict) -> dict:
|
||||
"""Transform the converted API in place before sending to Kibana."""
|
||||
|
||||
# cleanup the whitespace in the rule
|
||||
@@ -515,6 +572,59 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
|
||||
def type(self) -> str:
|
||||
return self.data.type
|
||||
|
||||
def _post_dict_transform(self, obj: dict) -> dict:
    """Transform the converted API in place before sending to Kibana.

    Runs the base transform, adds the build-time restricted fields, then loads the result
    through the rule type's data class so the added fields are schema-validated.
    """
    super()._post_dict_transform(obj)

    # populate the restricted fields, in the same order as before
    for add_field in (self.add_related_integrations, self.add_required_fields, self.add_setup):
        add_field(obj)

    # validate new fields against the schema
    data_cls = self.get_data_subclass(obj['type'])
    data_cls.from_dict(obj)

    return obj
|
||||
|
||||
def add_related_integrations(self, obj: dict) -> None:
    """Placeholder for adding the restricted field related_integrations to obj (currently a no-op)."""
    # field_name = "related_integrations"
    return None
|
||||
|
||||
def add_required_fields(self, obj: dict) -> None:
    """Set the restricted field required_fields on obj, using the fields found in the query.

    Non-query rules and lucene rules get an empty list. The field is only added when the
    restricted-field version check passes, and an existing value in obj is left untouched.
    """
    derivable = isinstance(self.data, QueryRuleData) and self.data.language != 'lucene'
    if derivable:
        computed = self.data.get_required_fields(obj.get('index') or [])
    else:
        computed = []

    field_name = "required_fields"
    if not self.check_restricted_field_version(field_name=field_name):
        return
    obj.setdefault(field_name, computed)
|
||||
|
||||
def add_setup(self, obj: dict) -> None:
    """Placeholder for adding the restricted field setup to obj (currently a no-op)."""
    # field_name = "setup"
    return None
|
||||
|
||||
def check_explicit_restricted_field_version(self, field_name: str) -> bool:
    """Check a restricted field against the min and max versions in BUILD_FIELD_VERSIONS."""
    version_bounds = BUILD_FIELD_VERSIONS[field_name]
    lower, upper = version_bounds
    return self.compare_field_versions(lower, upper)
|
||||
|
||||
def check_restricted_field_version(self, field_name: str) -> bool:
    """Check a restricted field against the min and max versions declared on the rule data schema."""
    restricted = self.data.get_restricted_fields
    lower, upper = restricted.get(field_name)
    return self.compare_field_versions(lower, upper)
|
||||
|
||||
def compare_field_versions(self, min_stack: Version, max_stack: Version) -> bool:
    """Check that the current package version is within the min and max stack versions.

    :param min_stack: lowest stack version (inclusive) that supports the field.
    :param max_stack: highest stack version (inclusive) that supports the field; a falsy value
        means there is no upper bound.
    :return: True when ``min_stack <= current package version <= max_stack``.
    """
    current_version = Version(load_current_package_version())
    # an open upper bound falls back to the current version, so the upper check always passes
    max_stack = max_stack or current_version
    # the upper bound was previously tested with `>=`, which inverted the check for an explicit max
    return Version(min_stack) <= current_version <= Version(max_stack)
|
||||
|
||||
@validates_schema
|
||||
def validate_query(self, value: dict, **kwargs):
|
||||
"""Validate queries by calling into the validator for the relevant method."""
|
||||
@@ -540,11 +650,11 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
|
||||
def to_api_format(self, include_version=True) -> dict:
    """Convert the TOML rule to the API format.

    :param include_version: when True, add the auto-bumped rule version under ``"version"``.
    :return: the rule as an API-format dict, after the post-dict transform.
    """
    converted = self.data.to_dict()

    if include_version:
        converted["version"] = self.autobumped_version

    # transform exactly once, after the version is set, so the transform sees the final payload
    converted = self._post_dict_transform(converted)

    return converted
|
||||
|
||||
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import List, Optional, Union
|
||||
import eql
|
||||
|
||||
import kql
|
||||
from . import ecs, beats
|
||||
from . import ecs
|
||||
from .rule import QueryValidator, QueryRuleData, RuleMeta
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class KQLValidator(QueryValidator):
|
||||
def ast(self) -> kql.ast.Expression:
|
||||
return kql.parse(self.query)
|
||||
|
||||
@cached_property
def unique_fields(self) -> List[str]:
    """Return the distinct field names referenced in the parsed KQL query, in no particular order."""
    # a single decorator only: wrapping cached_property in property makes the attribute unusable
    return list({str(node) for node in self.ast if isinstance(node, kql.ast.Field)})
|
||||
|
||||
@@ -29,9 +29,7 @@ class KQLValidator(QueryValidator):
|
||||
return kql.to_eql(self.query)
|
||||
|
||||
def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
|
||||
"""Static method to validate the query, called from the parent which contains [metadata] information."""
|
||||
ast = self.ast
|
||||
|
||||
"""Validate the query, called from the parent which contains [metadata] information."""
|
||||
if meta.query_schema_validation is False or meta.maturity == "deprecated":
|
||||
# syntax only, which is done via self.ast
|
||||
return
|
||||
@@ -41,9 +39,7 @@ class KQLValidator(QueryValidator):
|
||||
ecs_version = mapping['ecs']
|
||||
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
|
||||
|
||||
beat_types = beats.parse_beats_from_index(data.index)
|
||||
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
|
||||
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
|
||||
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
|
||||
|
||||
try:
|
||||
kql.parse(self.query, schema=schema)
|
||||
@@ -73,14 +69,12 @@ class EQLValidator(QueryValidator):
|
||||
|
||||
return [f for f in self.unique_fields if elasticsearch_type_family(eql_schema.kql_schema.get(f)) == 'text']
|
||||
|
||||
@cached_property
def unique_fields(self) -> List[str]:
    """Return the distinct field names referenced in the parsed EQL query, in no particular order."""
    # a single decorator only: wrapping cached_property in property makes the attribute unusable
    return list({str(node) for node in self.ast if isinstance(node, eql.ast.Field)})
|
||||
|
||||
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
|
||||
"""Validate an EQL query while checking TOMLRule."""
|
||||
ast = self.ast
|
||||
|
||||
if meta.query_schema_validation is False or meta.maturity == "deprecated":
|
||||
# syntax only, which is done via self.ast
|
||||
return
|
||||
@@ -90,9 +84,7 @@ class EQLValidator(QueryValidator):
|
||||
ecs_version = mapping['ecs']
|
||||
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
|
||||
|
||||
beat_types = beats.parse_beats_from_index(data.index)
|
||||
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
|
||||
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
|
||||
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
|
||||
eql_schema = ecs.KqlSchema2Eql(schema)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user