From 34ebcec679ca46c5c067ca6f917a2755396adfe3 Mon Sep 17 00:00:00 2001 From: eric-forte-elastic <119343520+eric-forte-elastic@users.noreply.github.com> Date: Tue, 5 Sep 2023 15:27:04 -0400 Subject: [PATCH] Added unit test (#3038) * Added unit test * removed print from unit test * fixed linting * Updated to put validation in init * Updated for cleanliness * removed Literal import --- rta/__init__.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/rta/__init__.py b/rta/__init__.py index aa4ba6135..9db5f72b6 100644 --- a/rta/__init__.py +++ b/rta/__init__.py @@ -11,8 +11,9 @@ from typing import Dict, List, Optional from . import common +# Definitions CURRENT_DIR = Path(__file__).resolve().parent - +RULE_META_KEYS = ["rule_id", "rule_name"] @dataclass class RtaMetadata: @@ -20,27 +21,36 @@ class RtaMetadata: uuid: str platforms: List[str] - path: Path = field(init=False) name: str = field(init=False) - endpoint: Optional[List[dict]] = None - siem: Optional[List[dict]] = None + endpoint: Optional[List[Dict[str, str]]] = None + siem: Optional[List[Dict[str, str]]] = None techniques: Optional[List[str]] = None def __post_init__(self): """Set the path and name based on the callee and check for platforms.""" - # set the path of the callee + # Set the path of the callee for frame in inspect.stack(): self.path = Path(frame.filename) self.name = self.path.name if frame.function == "" and valid_rta_file(self.path): break - # check for valid platforms + # Check for valid platforms if not self.platforms and (self.endpoint or self.siem): raise ValueError(f"RTA {self.name} has no platforms specified but has rule info provided.") + # Check for valid rule metadata + self._validate_rule_metadata(self.endpoint, "endpoint") + self._validate_rule_metadata(self.siem, "siem") + + def _validate_rule_metadata(self, rules: Optional[List[Dict[str, str]]], field_name: str): + """Check for valid rule metadata""" + if rules: + for rule in rules: + if sorted(rule.keys()) != RULE_META_KEYS: + raise ValueError(f"RTA {self.name} has invalid {field_name} field in metadata.") def valid_rta_file(file_path: str) -> bool: return file_path.stem not in ["init", "common", "main"] and not file_path.name.startswith("_")