diff --git a/detection_rules/cli_utils.py b/detection_rules/cli_utils.py index 49fd98e34..c66c243f4 100644 --- a/detection_rules/cli_utils.py +++ b/detection_rules/cli_utils.py @@ -18,6 +18,7 @@ import kql # type: ignore[reportMissingTypeStubs] from . import ecs from .attack import build_threat_map_entry, matrix, tactics from .config import parse_rules_config +from .mixins import get_dataclass_required_fields from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter from .schemas import definitions @@ -166,9 +167,10 @@ def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915 ) target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val) + required_fields = get_dataclass_required_fields(target_data_subclass) schema = target_data_subclass.jsonschema() props = schema["properties"] - required_fields = schema.get("required", []) + additional_required + required_fields = sorted(required_fields + additional_required) contents: dict[str, Any] = {} skipped: list[str] = [] diff --git a/detection_rules/mixins.py b/detection_rules/mixins.py index e02a2ce37..e2f200309 100644 --- a/detection_rules/mixins.py +++ b/detection_rules/mixins.py @@ -8,13 +8,14 @@ import dataclasses import json from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, get_type_hints import marshmallow import marshmallow_dataclass import marshmallow_dataclass.union_field import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs] import marshmallow_union # type: ignore[reportMissingTypeStubs] +import typing_inspect # type: ignore[reportMissingTypeStubs] from marshmallow import Schema, ValidationError, validates_schema from marshmallow import fields as marshmallow_fields from semver import Version @@ -38,6 +39,28 @@ def _strip_none_from_dict(obj: Any) -> Any: return obj +def get_dataclass_required_fields(cls: Any) -> list[str]: + """Get required fields based on both dataclass and type Annotations.""" + required_fields: list[str] = [] + hints = get_type_hints(cls, include_extras=True) + marshmallow_schema = marshmallow_dataclass.class_schema(cls)() + for dc_field in dataclasses.fields(cls): + hint = hints.get(dc_field.name) + if not hint: + continue + + mm_field = marshmallow_schema.fields.get(dc_field.name) + if mm_field is None: + continue + if dc_field.default is not dataclasses.MISSING: + continue + if getattr(dc_field, "default_factory", dataclasses.MISSING) is not dataclasses.MISSING: + continue + if not typing_inspect.is_optional_type(hint) or mm_field.required is True: # type: ignore[reportUnknownVariableType] + required_fields.append(dc_field.name) + return required_fields + + def patch_jsonschema(obj: Any) -> dict[str, Any]: """Patch marshmallow-jsonschema output to look more like JSL.""" @@ -264,5 +287,4 @@ class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema): default=field.default, # type: ignore[reportUnknownMemberType] allow_none=field.allow_none, ) - return super()._get_schema_for_field(obj, field) # type: ignore[reportUnknownMemberType] diff --git a/pyproject.toml b/pyproject.toml index 8df203798..1610d92ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "detection_rules" -version = "1.4.6" +version = "1.4.7" description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine." readme = "README.md" requires-python = ">=3.12"