diff --git a/detection_rules/version_lock.py b/detection_rules/version_lock.py index 564efe3cc..af067c338 100644 --- a/detection_rules/version_lock.py +++ b/detection_rules/version_lock.py @@ -24,16 +24,18 @@ MIN_LOCK_VERSION_DEFAULT = Version("7.13.0") @dataclass(frozen=True) -class VersionLockFileEntry(MarshmallowDataclassMixin): - """Schema for a rule entry in the version lock.""" +class BaseEntry: rule_name: definitions.RuleName sha256: definitions.Sha256 type: definitions.RuleType version: definitions.PositiveInteger - min_stack_version: Optional[definitions.SemVer] - # TODO: need to exclude nested 'previous' - previous: Optional[Dict[definitions.SemVer, 'VersionLockFileEntry']] + +@dataclass(frozen=True) +class VersionLockFileEntry(MarshmallowDataclassMixin, BaseEntry): + """Schema for a rule entry in the version lock.""" + min_stack_version: Optional[definitions.SemVer] + previous: Optional[Dict[definitions.SemVer, BaseEntry]] @dataclass(frozen=True) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index dce9042da..347fb87de 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -9,12 +9,13 @@ import unittest import uuid import eql - from detection_rules import utils from detection_rules.packaging import load_current_package_version from detection_rules.rule import TOMLRuleContents from detection_rules.schemas import downgrade from detection_rules.semver import Version +from detection_rules.version_lock import VersionLockFile +from marshmallow import ValidationError class TestSchemas(unittest.TestCase): @@ -236,6 +237,49 @@ class TestSchemas(unittest.TestCase): """) +class TestVersionLockSchema(unittest.TestCase): + """Test that the version lock has proper entries.""" + + @classmethod + def setUpClass(cls): + cls.version_lock_contents = { + "33f306e8-417c-411b-965c-c2812d6d3f4d": { + "rule_name": "Remote File Download via PowerShell", + "sha256": "8679cd72bf85b67dde3dcfdaba749ed1fa6560bca5efd03ed41c76a500ce31d6", + "type": "eql", + "version": 4 + }, + "34fde489-94b0-4500-a76f-b8a157cf9269": { + "min_stack_version": "8.2", + "previous": { + "7.13.0": { + "rule_name": "Telnet Port Activity", + "sha256": "3dd4a438c915920e6ddb0a5212603af5d94fb8a6b51a32f223d930d7e3becb89", + "type": "query", + "version": 9 + } + }, + "rule_name": "Telnet Port Activity", + "sha256": "b0bdfa73639226fb83eadc0303ad1801e0707743f96a36209aa58228d3bf6a89", + "type": "query", + "version": 10 + } + } + + def test_version_lock_no_previous(self): + """Pass field validation on version lock without nested previous fields""" + version_lock_contents = copy.deepcopy(self.version_lock_contents) + VersionLockFile.from_dict(dict(data=version_lock_contents)) + + def test_version_lock_has_nested_previous(self): + """Fail field validation on version lock with nested previous fields""" + version_lock_contents = copy.deepcopy(self.version_lock_contents) + with self.assertRaises(ValidationError): + previous = version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"] + version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"]["previous"] = previous + VersionLockFile.from_dict(dict(data=version_lock_contents)) + + class TestVersions(unittest.TestCase): """Test that schema versioning aligns."""