[Bug] Add Required to the Annotation (#5159)
* Add Required to the Annotation * Additional required fields * remove nonempty sting validation * Required Types via Annotated and Dataclass * remove space * Remove inline comment * Switch to getting a list * Fix typo and sort --------- Co-authored-by: Mika Ayenson, PhD <Mikaayenson@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import kql # type: ignore[reportMissingTypeStubs]
|
||||
from . import ecs
|
||||
from .attack import build_threat_map_entry, matrix, tactics
|
||||
from .config import parse_rules_config
|
||||
from .mixins import get_dataclass_required_fields
|
||||
from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents
|
||||
from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter
|
||||
from .schemas import definitions
|
||||
@@ -166,9 +167,10 @@ def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915
|
||||
)
|
||||
|
||||
target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val)
|
||||
required_fields = get_dataclass_required_fields(target_data_subclass)
|
||||
schema = target_data_subclass.jsonschema()
|
||||
props = schema["properties"]
|
||||
required_fields = schema.get("required", []) + additional_required
|
||||
required_fields = sorted(required_fields + additional_required)
|
||||
contents: dict[str, Any] = {}
|
||||
skipped: list[str] = []
|
||||
|
||||
|
||||
@@ -8,13 +8,14 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, get_type_hints
|
||||
|
||||
import marshmallow
|
||||
import marshmallow_dataclass
|
||||
import marshmallow_dataclass.union_field
|
||||
import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs]
|
||||
import marshmallow_union # type: ignore[reportMissingTypeStubs]
|
||||
import typing_inspect # type: ignore[reportMissingTypeStubs]
|
||||
from marshmallow import Schema, ValidationError, validates_schema
|
||||
from marshmallow import fields as marshmallow_fields
|
||||
from semver import Version
|
||||
@@ -38,6 +39,28 @@ def _strip_none_from_dict(obj: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def get_dataclass_required_fields(cls: Any) -> list[str]:
|
||||
"""Get required fields based on both dataclass and type Annotations."""
|
||||
required_fields: list[str] = []
|
||||
hints = get_type_hints(cls, include_extras=True)
|
||||
marshmallow_schema = marshmallow_dataclass.class_schema(cls)()
|
||||
for dc_field in dataclasses.fields(cls):
|
||||
hint = hints.get(dc_field.name)
|
||||
if not hint:
|
||||
continue
|
||||
|
||||
mm_field = marshmallow_schema.fields.get(dc_field.name)
|
||||
if mm_field is None:
|
||||
continue
|
||||
if dc_field.default is not dataclasses.MISSING:
|
||||
continue
|
||||
if getattr(dc_field, "default_factory", dataclasses.MISSING) is not dataclasses.MISSING:
|
||||
continue
|
||||
if not typing_inspect.is_optional_type(hint) or mm_field.required is True: # type: ignore[reportUnknownVariableType]
|
||||
required_fields.append(dc_field.name)
|
||||
return required_fields
|
||||
|
||||
|
||||
def patch_jsonschema(obj: Any) -> dict[str, Any]:
|
||||
"""Patch marshmallow-jsonschema output to look more like JSL."""
|
||||
|
||||
@@ -264,5 +287,4 @@ class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema):
|
||||
default=field.default, # type: ignore[reportUnknownMemberType]
|
||||
allow_none=field.allow_none,
|
||||
)
|
||||
|
||||
return super()._get_schema_for_field(obj, field) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "detection_rules"
|
||||
version = "1.4.6"
|
||||
version = "1.4.7"
|
||||
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
Reference in New Issue
Block a user