Files
sigma-rules/detection_rules/mixins.py
T
Ross Wolf 6ed1a39efe Add a RuleCollection object instead of a "loader" module (#1063)
* Add a RuleCollection object instead of a "loader" module
* Remove legacy loader code
* Remove more legacy loader
* Freeze the default collection
* Change RULE_LOADER default
* Rename to _toml_load_cache
* Use rglob magic
* Typo should've been a string
* Remove no longer needed glob import
* Fix pycharm import bad ordering
* Restore the detection_rules/schemas imports
* Put more imports back for a smaller diff
* Check cache in _deserialize_toml
* Add multi collection and single collection decorators
* Reorder RuleCollection methods
* Move filter method up
2021-04-05 14:23:37 -06:00

57 lines
1.8 KiB
Python

# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Generic mixin classes."""
from typing import TypeVar, Type
import marshmallow_dataclass
from marshmallow import Schema
from .utils import cached
# Generic type variable: lets a helper declare that it returns the same type it was given.
T = TypeVar('T')
# Type variable standing for the concrete class that inherits the mixin.
# Left unbound for now; the open question is whether it should be bound to a dataclass type.
ClassT = TypeVar('ClassT') # bound=dataclass?
def _strip_none_from_dict(obj: T) -> T:
"""Strip none values from a dict recursively."""
if isinstance(obj, dict):
return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None}
if isinstance(obj, list):
return [_strip_none_from_dict(o) for o in obj]
if isinstance(obj, tuple):
return tuple(_strip_none_from_dict(list(obj)))
return obj
class MarshmallowDataclassMixin:
    """Mixin that gives a dataclass marshmallow-backed (de)serialization helpers."""

    @classmethod
    @cached
    def __schema(cls: ClassT) -> Schema:
        """Return a marshmallow schema instance generated from this class.

        NOTE(review): the method is wrapped in ``cached`` from ``.utils``, so the
        schema is presumably built once per class and reused -- confirm against
        that decorator's implementation.
        """
        schema_cls = marshmallow_dataclass.class_schema(cls)
        return schema_cls()

    def get(self, key: str):
        """Return the attribute named ``key``, or ``None`` if it is not set."""
        return getattr(self, key, None)

    @classmethod
    def from_dict(cls: Type[ClassT], obj: dict) -> ClassT:
        """Build an instance from ``obj`` by loading it through the class schema."""
        return cls.__schema().load(obj)

    def to_dict(self, strip_none_values=True) -> dict:
        """Dump this object to a dictionary through the class schema.

        When ``strip_none_values`` is true (the default), the dumped data is
        passed through ``_strip_none_from_dict`` before being returned.
        """
        dumped: dict = self.__schema().dump(self)
        if not strip_none_values:
            return dumped
        return _strip_none_from_dict(dumped)