From cf736046f1f28e12f7c8e9701217dc470dadcc00 Mon Sep 17 00:00:00 2001 From: Ross Wolf <31489089+rw-access@users.noreply.github.com> Date: Thu, 8 Jul 2021 13:44:04 -0600 Subject: [PATCH] Add command to unstage incompatible rules from git (#1317) * Add devtools unstage-incompatible-rules command * Create ephemeral GitChangeEntry for R->D+A * Undo changes to Github job * Fix typo in comment * s/previous_path/original_path --- detection_rules/devtools.py | 85 +++++++++++++++++++++++++++++++++- detection_rules/rule_loader.py | 13 ++++-- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 8cdd31830..738933208 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -13,11 +13,11 @@ import shutil import subprocess import textwrap import time -import typing from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, List import click +import typing from elasticsearch import Elasticsearch from kibana.connector import Kibana @@ -30,6 +30,7 @@ from .misc import PYTHON_LICENSE, add_client, client_error from .packaging import PACKAGE_FILE, Package, RELEASE_DIR, current_stack_version, manage_versions from .rule import AnyRuleData, BaseRuleData, QueryRuleData, TOMLRule from .rule_loader import RuleCollection, production_filter +from .semver import Version from .utils import dict_hash, get_path, load_dump RULES_DIR = get_path('rules') @@ -74,6 +75,86 @@ def build_release(config_file, update_version_lock, release=None, verbose=True): return package +@dataclasses.dataclass +class GitChangeEntry: + status: str + original_path: Path + new_path: Optional[Path] = None + + @classmethod + def from_line(cls, text: str) -> 'GitChangeEntry': + columns = text.split("\t") + assert 2 <= len(columns) <= 3 + + columns[1:] = [Path(c) for c in columns[1:]] + return cls(*columns) + + @property + def path(self) -> Path: + return self.new_path or self.original_path + + def revert(self, dry_run=False): + """Run a git command to revert this change.""" + def git(*args): + command_line = ["git"] + [str(arg) for arg in args] + click.echo(subprocess.list2cmdline(command_line)) + + if not dry_run: + subprocess.check_call(command_line) + + if self.status.startswith("R"): + # renames are actually Delete (D) and Add (A) + # revert in opposite order + GitChangeEntry("A", self.new_path).revert(dry_run=dry_run) + GitChangeEntry("D", self.original_path).revert(dry_run=dry_run) + return + + # remove the file from the staging area (A|M|D) + git("restore", "--staged", self.original_path) + + def read(self, git_tree="HEAD") -> bytes: + """Read the file from disk or git.""" + if self.status == "D": + # deleted files need to be recovered from git + return subprocess.check_output(["git", "show", f"{git_tree}:{self.path}"]) + + return self.path.read_bytes() + + +@dev_group.command("unstage-incompatible-rules") +@click.option("--target-stack-version", "-t", help="Minimum stack version to filter the staging area", required=True) +@click.option("--dry-run", is_flag=True, help="List the changes that would be made") +def prune_staging_area(target_stack_version: str, dry_run: bool): + """Prune the git staging area to remove changes to incompatible rules.""" + target_stack_version = Version(target_stack_version)[:2] + + # load a structured summary of the diff from git + git_output = subprocess.check_output(["git", "diff", "--name-status", "HEAD"]) + changes = [GitChangeEntry.from_line(line) for line in git_output.decode("utf-8").splitlines()] + + # track which changes need to be reverted because of incompatibilities + reversions: List[GitChangeEntry] = [] + + for change in changes: + # it's a change to a rule file, load it and check the version + if str(change.path.absolute()).startswith(RULES_DIR) and change.path.suffix == ".toml": + # bypass TOML validation in case there were schema changes + dict_contents = RuleCollection.deserialize_toml_string(change.read()) + min_stack_version: Optional[str] = dict_contents.get("metadata", {}).get("min_stack_version") + + if min_stack_version is not None and target_stack_version < Version(min_stack_version)[:2]: + # rule is incompatible, add to the list of reversions to make later + reversions.append(change) + + if len(reversions) == 0: + click.echo("No files restored from staging area") + return + + click.echo(f"Restoring {len(reversions)} changes from the staging area...") + for change in reversions: + change.revert(dry_run=dry_run) + + @dev_group.command('update-lock-versions') @click.argument('rule-ids', nargs=-1, required=False) def update_lock_versions(rule_ids): diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py index 66596aba7..ade4d6d7c 100644 --- a/detection_rules/rule_loader.py +++ b/detection_rules/rule_loader.py @@ -5,6 +5,7 @@ """Load rule metadata transform between rule and api formats.""" import io +import typing from collections import OrderedDict from pathlib import Path from typing import Dict, List, Iterable, Callable, Optional @@ -111,7 +112,11 @@ class RuleCollection: return filtered_collection - def _deserialize_toml(self, path: Path) -> dict: + @staticmethod + def deserialize_toml_string(contents: typing.Union[bytes, str]) -> dict: + return pytoml.loads(contents) + + def _load_toml_file(self, path: Path) -> dict: if path in self._toml_load_cache: return self._toml_load_cache[path] @@ -119,7 +124,7 @@ class RuleCollection: # https://github.com/uiri/toml/issues/152 # might also be worth looking at https://github.com/sdispater/tomlkit with io.open(str(path.resolve()), "r", encoding="utf-8") as f: - toml_dict = pytoml.load(f) + toml_dict = self.deserialize_toml_string(f.read()) self._toml_load_cache[path] = toml_dict return toml_dict @@ -157,7 +162,7 @@ class RuleCollection: self.add_rule(rule) return rule - obj = self._deserialize_toml(path) + obj = self._load_toml_file(path) return self.load_dict(obj, path=path) except Exception: print(f"Error loading rule in {path}") @@ -171,7 +176,7 @@ class RuleCollection: def load_directory(self, directory: Path, recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None): paths = self._get_paths(directory, recursive=recursive) if toml_filter is not None: - paths = [path for path in paths if toml_filter(self._deserialize_toml(path))] + paths = [path for path in paths if toml_filter(self._load_toml_file(path))] self.load_files(paths)