diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 425ea677a..d26d07d70 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -279,7 +279,7 @@ def kibana_commit(ctx, local_repo: str, github_repo: str, ssh: bool, kibana_dire branch_name = branch_name or f"detection-rules/{package_name}-{short_commit_hash}" - git("checkout", "-b", branch_name, show_output=True) + git("checkout", "-b", branch_name, print_output=True) git("rm", "-r", kibana_directory) source_dir = os.path.join(release_dir, "rules") @@ -295,7 +295,7 @@ def kibana_commit(ctx, local_repo: str, github_repo: str, ssh: bool, kibana_dire git("add", kibana_directory) git("commit", "--no-verify", "-m", message) - git("status", show_output=True) + git("status", print_output=True) if push: git("push", "origin", branch_name) diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py index ade4d6d7c..d45c6dc8d 100644 --- a/detection_rules/rule_loader.py +++ b/detection_rules/rule_loader.py @@ -5,14 +5,14 @@ """Load rule metadata transform between rule and api formats.""" import io -import typing from collections import OrderedDict from pathlib import Path -from typing import Dict, List, Iterable, Callable, Optional +from typing import Dict, List, Iterable, Callable, Optional, Union import click import pytoml +from . import utils from .mappings import RtaMappings from .rule import TOMLRule, TOMLRuleContents from .schemas import definitions @@ -113,7 +113,7 @@ class RuleCollection: return filtered_collection @staticmethod - def deserialize_toml_string(contents: typing.Union[bytes, str]) -> dict: + def deserialize_toml_string(contents: Union[bytes, str]) -> dict: return pytoml.loads(contents) def _load_toml_file(self, path: Path) -> dict: @@ -123,7 +123,7 @@ class RuleCollection: # use pytoml instead of toml because of annoying bugs # https://github.com/uiri/toml/issues/152 # might also be worth looking at https://github.com/sdispater/tomlkit - with io.open(str(path.resolve()), "r", encoding="utf-8") as f: + with io.open(path, "r", encoding="utf-8") as f: toml_dict = self.deserialize_toml_string(f.read()) self._toml_load_cache[path] = toml_dict return toml_dict @@ -168,6 +168,21 @@ class RuleCollection: print(f"Error loading rule in {path}") raise + def load_git_branch(self, branch: str): + """Load rules from a Git branch.""" + git = utils.make_git() + rules_dir = DEFAULT_RULES_DIR.relative_to(get_path(".")) + paths = git("ls-files", "--with-tree", branch, rules_dir).splitlines() + + for path in paths: + path = Path(path) + if path.suffix != ".toml": + continue + + contents = git("show", f"{branch}:{path}") + toml_dict = self.deserialize_toml_string(contents) + self.load_dict(toml_dict, path) + def load_files(self, paths: Iterable[Path]): """Load multiple files into the collection.""" for path in paths: @@ -262,7 +277,6 @@ def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rul rta_mappings = RtaMappings() - __all__ = ( "FILE_PATTERN", "DEFAULT_RULES_DIR", diff --git a/detection_rules/utils.py b/detection_rules/utils.py index 3ed9bdd32..9f8dc799a 100644 --- a/detection_rules/utils.py +++ b/detection_rules/utils.py @@ -330,15 +330,20 @@ def make_git(*prefix_args) -> Optional[Callable]: return - def git(*args, show_output=False): + def git(*args, print_output=False): full_args = [git_exe] + prefix_args + [str(arg) for arg in args] - if show_output: - return subprocess.check_output(full_args, encoding="utf-8").rstrip() - return subprocess.check_call(full_args) + if print_output: + return subprocess.check_call(full_args) + return subprocess.check_output(full_args, encoding="utf-8").rstrip() return git +def git(*args, **kwargs): + """Find and run a one-off Git command.""" + return make_git()(*args, **kwargs) + + def add_params(*params): """Add parameters to a click command."""