Files
Eric Forte 57bf1546dd [Bug] [DAC] Add filtering to export-rules-from-repo (#5769)
* Add filtering to export-rules-from-repo
2026-03-10 13:03:52 -04:00

225 lines
8.8 KiB
Python

# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Load generic toml formatted files for exceptions and actions."""
from collections.abc import Callable, Iterable, Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
import pytoml # type: ignore[reportMissingTypeStubs]
from .action import TOMLAction, TOMLActionContents
from .action_connector import TOMLActionConnector, TOMLActionConnectorContents
from .config import parse_rules_config
from .exception import TOMLException, TOMLExceptionContents
from .rule_loader import dict_filter
if TYPE_CHECKING:
from .schemas import definitions
# Rules configuration, parsed once when this module is imported.
RULES_CONFIG = parse_rules_config()
# Union of the item wrapper classes that GenericCollection stores.
GenericCollectionTypes = TOMLAction | TOMLActionConnector | TOMLException
# Union of the ``contents`` classes wrapped by those item classes.
GenericCollectionContentTypes = TOMLActionContents | TOMLActionConnectorContents | TOMLExceptionContents
def matches_rule_ids(item: GenericCollectionTypes, rule_ids: set[str]) -> bool:
    """Tell whether any rule ID recorded in the item's metadata is in ``rule_ids``.

    The metadata may carry its IDs under ``rule_ids`` or ``rule_id``; whichever
    attributes exist are concatenated, and a missing attribute counts as empty.
    """
    plural_ids = getattr(item.contents.metadata, "rule_ids", [])
    singular_ids = getattr(item.contents.metadata, "rule_id", [])
    for candidate in plural_ids + singular_ids:
        if candidate in rule_ids:
            return True
    return False
def metadata_filter(**metadata: Any) -> Callable[[GenericCollectionTypes], bool]:
    """Build a predicate that tests an item's metadata against the given criteria.

    The keyword arguments are passed to ``dict_filter`` as one dict, and the
    returned callback applies the resulting filter to
    ``item.contents.metadata.to_dict()``.
    """
    matches_criteria = dict_filter(metadata)

    def callback(item: GenericCollectionTypes) -> bool:
        return matches_criteria(item.contents.metadata.to_dict())

    return callback
class GenericCollection:
    """In-memory collection of action, action connector and exception items.

    Items are indexed by name (``name_map``), by resolved file path when they
    have one (``file_map``) and by ``id`` when they expose one (``id_map``).
    Names and file paths must be unique within a collection. A collection can
    be frozen, after which adding items raises ``ValueError``.
    """

    items: list[GenericCollectionTypes]
    # Shared collection built on first use by ``default()``.
    __default = None

    def __init__(self, items: list[GenericCollectionTypes] | None = None) -> None:
        """Create a collection, optionally seeded with ``items``.

        Raises:
            ValueError: if the given items collide on name or file path.
        """
        self.id_map: dict[definitions.UUIDString, GenericCollectionTypes] = {}
        self.file_map: dict[Path, GenericCollectionTypes] = {}
        self.name_map: dict[definitions.RuleName, GenericCollectionTypes] = {}
        self.items: list[GenericCollectionTypes] = []
        self.errors: dict[Path, Exception] = {}
        self.frozen = False
        # Parsed TOML keyed by path, so a file that is filtered and then
        # loaded is only read and parsed once.
        self._toml_load_cache: dict[Path, dict[str, Any]] = {}

        for item in items or []:
            self.add_item(item)

    def __len__(self) -> int:
        """Get the total number of items in the collection."""
        return len(self.items)

    def __iter__(self) -> Iterator[GenericCollectionTypes]:
        """Iterate over all items in the collection, in insertion order."""
        return iter(self.items)

    def __contains__(self, item: GenericCollectionTypes) -> bool:
        """Check whether an equivalent item is already in the collection.

        An item that exposes an ``id`` is looked up in ``id_map`` first. Every
        item is then looked up by name, the key ``_assert_new`` keeps unique,
        so items without an ``id`` attribute are handled as well.
        """
        item_id = getattr(item, "id", None)
        if item_id is not None and item_id in self.id_map:
            return True
        return item.name in self.name_map

    def filter(self, cb: Callable[[GenericCollectionTypes], bool]) -> "GenericCollection":
        """Return a new, unfrozen collection holding the items for which ``cb`` is truthy."""
        filtered_collection = GenericCollection()
        for item in self.items:
            if cb(item):
                filtered_collection.add_item(item)
        return filtered_collection

    def items_matching(
        self,
        contents_type: type[GenericCollectionContentTypes],
        rule_ids: set[str],
    ) -> list[GenericCollectionTypes]:
        """Return items whose contents are of the given type and match any of the rule IDs."""
        return [
            item
            for item in self.items
            if isinstance(item.contents, contents_type) and matches_rule_ids(item, rule_ids)
        ]

    @staticmethod
    def deserialize_toml_string(contents: bytes | str) -> dict[str, Any]:
        """Deserialize a TOML string into a dictionary."""
        return pytoml.loads(contents)  # type: ignore[reportUnknownVariableType]

    def _load_toml_file(self, path: Path) -> dict[str, Any]:
        """Load a TOML file into a dictionary, reusing a previous parse of the same path.

        The cached dict itself is returned, not a copy.
        """
        if path in self._toml_load_cache:
            return self._toml_load_cache[path]

        # use pytoml instead of toml because of annoying bugs
        # https://github.com/uiri/toml/issues/152
        # might also be worth looking at https://github.com/sdispater/tomlkit
        toml_dict = self.deserialize_toml_string(path.read_text(encoding="utf-8"))
        self._toml_load_cache[path] = toml_dict
        return toml_dict

    def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]:
        """Get all ``*.toml`` files in a directory as a sorted list.

        With ``recursive`` the search includes subdirectories.
        """
        matches = directory.rglob("*.toml") if recursive else directory.glob("*.toml")
        return sorted(matches)

    def _assert_new(self, item: GenericCollectionTypes) -> None:
        """Check that the item can be added and register its file path.

        Besides validating, this records the item's resolved path in
        ``file_map``; ``add_item`` relies on that side effect.

        Raises:
            ValueError: if the collection is frozen, the name is already taken,
                or the item's resolved path is already loaded.
        """
        if self.frozen:
            raise ValueError(f"Unable to add item {item.name} to a frozen collection")

        if item.name in self.name_map:
            raise ValueError(f"Rule Name {item.name} collides with {self.name_map[item.name].name}")

        if item.path is not None:
            item_path = item.path.resolve()
            if item_path in self.file_map:
                raise ValueError(f"Item file {item_path} already loaded")
            self.file_map[item_path] = item

    def add_item(self, item: GenericCollectionTypes) -> None:
        """Add a new item to the collection and index it.

        Raises:
            ValueError: if the item fails the checks in ``_assert_new``.
        """
        self._assert_new(item)
        self.name_map[item.name] = item
        # Not every item type exposes an ``id``; index it only when present so
        # ``__contains__`` can use ``id_map``.
        item_id = getattr(item, "id", None)
        if item_id is not None:
            self.id_map[item_id] = item
        self.items.append(item)

    def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> GenericCollectionTypes:
        """Build an item from a parsed TOML dictionary and add it to the collection.

        The item type is chosen by the first of the keys ``exceptions``,
        ``actions`` or ``action_connectors`` found in ``obj``. Actions and
        action connectors require ``path``; exceptions do not.

        Raises:
            ValueError: if none of the keys is present, if a required ``path``
                is missing, or if the item cannot be added.
        """
        if "exceptions" in obj:
            contents = TOMLExceptionContents.from_dict(obj)
            item = TOMLException(path=path, contents=contents)
        elif "actions" in obj:
            # Check the cheap precondition before parsing, as the
            # action_connectors branch does.
            if not path:
                raise ValueError("No path value provided")
            contents = TOMLActionContents.from_dict(obj)
            item = TOMLAction(path=path, contents=contents)
        elif "action_connectors" in obj:
            if not path:
                raise ValueError("No path value provided")
            contents = TOMLActionConnectorContents.from_dict(obj)
            item = TOMLActionConnector(path=path, contents=contents)
        else:
            raise ValueError("Invalid object type")

        self.add_item(item)
        return item

    def load_file(self, path: Path) -> GenericCollectionTypes:
        """Load a single file into the collection and return the resulting item.

        Any exception is re-raised after the offending path has been printed.
        """
        try:
            path = path.resolve()

            # use the default generic loader as a cache.
            # if it already loaded the item, then we can just use it from that
            # Compare against None explicitly: truth-testing a collection goes
            # through __len__.
            default = self.__default
            if default is not None and self is not default and path in default.file_map:
                item = default.file_map[path]
                self.add_item(item)
                return item

            obj = self._load_toml_file(path)
            return self.load_dict(obj, path=path)
        except Exception:
            print(f"Error loading item in {path}")
            raise

    def load_files(self, paths: Iterable[Path]) -> None:
        """Load multiple files into the collection, stopping at the first failure."""
        for path in paths:
            _ = self.load_file(path)

    def load_directory(
        self,
        directory: Path,
        recursive: bool = True,
        toml_filter: Callable[[dict[str, Any]], bool] | None = None,
    ) -> None:
        """Load all TOML files in a directory.

        When ``toml_filter`` is given it is called with each file's parsed TOML
        dictionary, and only files for which it returns a truthy value are loaded.
        """
        paths = self._get_paths(directory, recursive=recursive)
        if toml_filter is not None:
            paths = [path for path in paths if toml_filter(self._load_toml_file(path))]

        self.load_files(paths)

    def load_directories(
        self,
        directories: Iterable[Path],
        recursive: bool = True,
        toml_filter: Callable[[dict[str, Any]], bool] | None = None,
    ) -> None:
        """Load all TOML files in multiple directories, with the same options for each."""
        for path in directories:
            self.load_directory(path, recursive=recursive, toml_filter=toml_filter)

    def freeze(self) -> None:
        """Freeze the generic collection so that further additions raise ``ValueError``."""
        self.frozen = True

    @classmethod
    def default(cls) -> "GenericCollection":
        """Return the default item collection, building it on first use.

        The collection is loaded from whichever of the exception, action and
        action connector directories are set in ``RULES_CONFIG``, then frozen
        and kept on the class for later calls.
        """
        if cls.__default is None:
            collection = GenericCollection()
            if RULES_CONFIG.exception_dir:
                collection.load_directory(RULES_CONFIG.exception_dir)
            if RULES_CONFIG.action_dir:
                collection.load_directory(RULES_CONFIG.action_dir)
            if RULES_CONFIG.action_connector_dir:
                collection.load_directory(RULES_CONFIG.action_connector_dir)
            collection.freeze()
            cls.__default = collection

        return cls.__default