From b8116a5b7703555b40351a59c8bb72a23be87f16 Mon Sep 17 00:00:00 2001 From: Justin Ibarra Date: Mon, 8 Feb 2021 21:35:44 -0900 Subject: [PATCH] Add GitHub PR rule loader (#670) * add load_github_pr_rules function * add dev package-stats command * add dev search-rule-prs command, which extends the same functionality in rule-search to rules in PRs --- detection_rules/devtools.py | 98 +++++++++++++++++++++++++++++++-- detection_rules/main.py | 8 ++- detection_rules/packaging.py | 3 +- detection_rules/rule_loader.py | 63 ++++++++++++++++++++- detection_rules/schemas/base.py | 3 +- 5 files changed, 163 insertions(+), 12 deletions(-) diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 20a39ce55..0edaba793 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -40,14 +40,23 @@ def dev_group(): @click.argument('config-file', type=click.Path(exists=True, dir_okay=False), required=False, default=PACKAGE_FILE) @click.option('--update-version-lock', '-u', is_flag=True, help='Save version.lock.json file with updated rule versions in the package') -def build_release(config_file, update_version_lock): +def build_release(config_file, update_version_lock, release=None, verbose=True): """Assemble all the rules into Kibana-ready release files.""" config = load_dump(config_file)['package'] - click.echo('[+] Building package {}'.format(config.get('name'))) - package = Package.from_config(config, update_version_lock=update_version_lock, verbose=True) - package.save() - package.get_package_hash(verbose=True) - click.echo('- {} rules included'.format(len(package.rules))) + if release is not None: + config['release'] = release + + if verbose: + click.echo('[+] Building package {}'.format(config.get('name'))) + + package = Package.from_config(config, update_version_lock=update_version_lock, verbose=verbose) + package.save(verbose=verbose) + + if verbose: + package.get_package_hash(verbose=True) + click.echo(f'- {len(package.rules)} rules 
included') + + return package @dev_group.command('update-lock-versions') @@ -210,6 +219,83 @@ def license_check(ctx): ctx.exit(int(failed)) +@dev_group.command('package-stats') +@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') +@click.option('--threads', default=50, help='Number of threads to download rules from GitHub') +@click.pass_context +def package_stats(ctx, token, threads): + """Get statistics for current rule package.""" + current_package: Package = ctx.invoke(build_release, verbose=False, release=None) + release = f'v{current_package.name}.0' + new, modified, errors = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads) + + click.echo(f'Total rules as of {release} package: {len(current_package.rules)}') + click.echo(f'New rules: {len(current_package.new_rules_ids)}') + click.echo(f'Modified rules: {len(current_package.changed_rule_ids)}') + click.echo(f'Deprecated rules: {len(current_package.removed_rule_ids)}') + + click.echo('\n-----\n') + click.echo('Rules in active PRs for current package: ') + click.echo(f'New rules: {len(new)}') + click.echo(f'Modified rules: {len(modified)}') + + +@dev_group.command('search-rule-prs') +@click.argument('query', required=False) +@click.option('--no-loop', '-n', is_flag=True, help='Run once with no loop') +@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') +@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") +@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') +@click.option('--threads', default=50, help='Number of threads to download rules from GitHub') +@click.pass_context +def search_rule_prs(ctx, no_loop, query, columns, language, token, threads): + """Use KQL or EQL to find matching rules from active GitHub PRs.""" + from uuid import uuid4 + from .main import search_rules + + all_rules = {} + 
new, modified, errors = rule_loader.load_github_pr_rules(token=token, threads=threads) + + def add_github_meta(this_rule, status, original_rule_id=None): + pr = this_rule.gh_pr + rule.metadata['status'] = status + rule.metadata['github'] = { + 'base': pr.base.label, + 'comments': [c.body for c in pr.get_comments()], + 'commits': pr.commits, + 'created_at': str(pr.created_at), + 'head': pr.head.label, + 'is_draft': pr.draft, + 'labels': [lbl.name for lbl in pr.get_labels()], + 'last_modified': str(pr.last_modified), + 'title': pr.title, + 'url': pr.html_url, + 'user': pr.user.login + } + + if original_rule_id: + rule.metadata['original_rule_id'] = original_rule_id + rule.contents['rule_id'] = str(uuid4()) + + rule_path = f'pr-{pr.number}-{rule.path}' + all_rules[rule_path] = rule.rule_format() + + for rule_id, rule in new.items(): + add_github_meta(rule, 'new') + + for rule_id, rules in modified.items(): + for rule in rules: + add_github_meta(rule, 'modified', rule_id) + + loop = not no_loop + ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=loop) + + while loop: + query = click.prompt(f'Search loop - enter new {language} query or ctrl-z to exit') + columns = click.prompt('columns', default=','.join(columns)).split(',') + ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=True) + + @dev_group.group('test') def test_group(): """Commands for testing against stack resources.""" diff --git a/detection_rules/main.py b/detection_rules/main.py index b31f6aca2..d799ff873 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -9,6 +9,7 @@ import os import re import time from pathlib import Path +from typing import Dict import click import jsonschema @@ -250,7 +251,7 @@ def validate_all(fail): @click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') @click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") 
@click.option('--count', is_flag=True, help='Return a count rather than table') -def search_rules(query, columns, language, count, verbose=True): +def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, dict] = None, pager=False): """Use KQL or EQL to find matching rules.""" from kql import get_evaluator from eql.table import Table @@ -259,8 +260,9 @@ def search_rules(query, columns, language, count, verbose=True): from eql.pipes import CountPipe flattened_rules = [] + rules = rules or rule_loader.load_rule_files(verbose=verbose) - for file_name, rule_doc in rule_loader.load_rule_files(verbose=verbose).items(): + for file_name, rule_doc in rules.items(): flat = {"file": os.path.relpath(file_name)} flat.update(rule_doc) flat.update(rule_doc["metadata"]) @@ -309,7 +311,7 @@ def search_rules(query, columns, language, count, verbose=True): table = Table.from_list(columns, filtered) if verbose: - click.echo(table) + click.echo_via_pager(table) if pager else click.echo(table) return filtered diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index 78542a3d4..41b9985bf 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -311,7 +311,8 @@ class Package(object): click.echo(f' - {len(all_rules) - len(rules)} rules excluded from package') update = config.pop('update', {}) - package = cls(rules, deprecated_rules=deprecated_rules, update_version_lock=update_version_lock, **config) + package = cls(rules, deprecated_rules=deprecated_rules, update_version_lock=update_version_lock, + verbose=verbose, **config) # Allow for some fields to be overwritten if update.get('data', {}): diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py index f04fa5c7f..8bcb64710 100644 --- a/detection_rules/rule_loader.py +++ b/detection_rules/rule_loader.py @@ -9,6 +9,7 @@ import io import os import re from collections import OrderedDict +from typing import Dict, List import click import pytoml @@ -76,7 
+77,7 @@ def load_rules(file_lookup=None, verbose=True, error=True): file_lookup = file_lookup or load_rule_files(verbose=verbose) failed = False - rules = [] # type: list[Rule] + rules: List[Rule] = [] errors = [] queries = [] query_check_index = [] @@ -132,6 +133,65 @@ def load_rules(file_lookup=None, verbose=True, error=True): return OrderedDict([(rule.id, rule) for rule in sorted(rules, key=lambda r: r.name)]) +@cached +def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rules', token=None, threads=50, + verbose=True): + """Load all rules active as a GitHub PR.""" + import requests + import pytoml + from multiprocessing.pool import ThreadPool + from pathlib import Path + from .misc import GithubClient + + github = GithubClient(token=token) + repo = github.client.get_repo(repo) + labels = set(labels or []) + open_prs = [r for r in repo.get_pulls() if not labels.difference(set(list(lbl.name for lbl in r.get_labels())))] + + new_rules: List[Rule] = [] + modified_rules: List[Rule] = [] + errors: Dict[str, list] = {} + + existing_rules = load_rules(verbose=False) + pr_rules = [] + + if verbose: + click.echo('Downloading rules from GitHub PRs') + + def download_worker(pr_info): + pull, rule_file = pr_info + response = requests.get(rule_file.raw_url) + try: + raw_rule = pytoml.loads(response.text) + rule = Rule(rule_file.filename, raw_rule) + rule.gh_pr = pull + + if rule.id in existing_rules: + modified_rules.append(rule) + else: + new_rules.append(rule) + + except Exception as e: + errors.setdefault(Path(rule_file.filename).name, []).append(str(e)) + + for pr in open_prs: + pr_rules.extend([(pr, f) for f in pr.get_files() + if f.filename.startswith('rules/') and f.filename.endswith('.toml')]) + + pool = ThreadPool(processes=threads) + pool.map(download_worker, pr_rules) + pool.close() + pool.join() + + new = OrderedDict([(rule.id, rule) for rule in sorted(new_rules, key=lambda r: r.name)]) + modified = OrderedDict() + + for modified_rule in 
sorted(modified_rules, key=lambda r: r.name): + modified.setdefault(modified_rule.id, []).append(modified_rule) + + return new, modified, errors + + @cached def get_rule(rule_id=None, rule_name=None, file_name=None, verbose=True): """Get a rule based on its id.""" @@ -195,6 +255,7 @@ __all__ = ( "load_rule_files", "load_rules", "load_rule_files", + "load_github_pr_rules", "get_file_name", "get_production_rules", "get_rule", diff --git a/detection_rules/schemas/base.py b/detection_rules/schemas/base.py index 57598fc12..d44840f38 100644 --- a/detection_rules/schemas/base.py +++ b/detection_rules/schemas/base.py @@ -74,14 +74,15 @@ class TomlMetadata(GenericSchema): # rule validated against each ecs schema contained beats_version = jsl.StringField(pattern=VERSION_PATTERN, required=False) + comments = jsl.StringField(required=False) ecs_versions = jsl.ArrayField(jsl.StringField(pattern=VERSION_PATTERN, required=True), required=False) maturity = jsl.StringField(enum=MATURITY_LEVELS, default='development', required=True) os_type_list = jsl.ArrayField(jsl.StringField(enum=OS_OPTIONS), required=False) + query_schema_validation = jsl.BooleanField(required=False) related_endpoint_rules = jsl.ArrayField(jsl.ArrayField(jsl.StringField(), min_items=2, max_items=2), required=False) updated_date = jsl.StringField(required=True, pattern=DATE_PATTERN, default=time.strftime('%Y/%m/%d')) - query_schema_validation = jsl.BooleanField(required=False) class BaseApiSchema(GenericSchema):