Add GitHub PR rule loader (#670)
* add load_gh_pr_rules function
* add dev package-stats command
* add dev search-rule-prs command, which extends the same rule-search functionality to rules in open PRs
This commit is contained in:
@@ -40,14 +40,23 @@ def dev_group():
|
||||
@click.argument('config-file', type=click.Path(exists=True, dir_okay=False), required=False, default=PACKAGE_FILE)
@click.option('--update-version-lock', '-u', is_flag=True,
              help='Save version.lock.json file with updated rule versions in the package')
def build_release(config_file, update_version_lock, release=None, verbose=True):
    """Assemble all the rules into Kibana-ready release files.

    :param config_file: path to the package config file (defaults to PACKAGE_FILE)
    :param update_version_lock: save version.lock.json with updated rule versions
    :param release: optional release override written into the package config
    :param verbose: when False, suppress progress output (used by programmatic callers
                    such as the package-stats command)
    :return: the built Package
    """
    config = load_dump(config_file)['package']

    # Allow callers (e.g. ctx.invoke) to override the configured release.
    if release is not None:
        config['release'] = release

    if verbose:
        click.echo('[+] Building package {}'.format(config.get('name')))

    package = Package.from_config(config, update_version_lock=update_version_lock, verbose=verbose)
    package.save(verbose=verbose)

    if verbose:
        package.get_package_hash(verbose=True)
        click.echo(f'- {len(package.rules)} rules included')

    return package
||||
@dev_group.command('update-lock-versions')
|
||||
@@ -210,6 +219,83 @@ def license_check(ctx):
|
||||
ctx.exit(int(failed))
|
||||
|
||||
|
||||
@dev_group.command('package-stats')
@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)')
@click.option('--threads', default=50, help='Number of threads to download rules from GitHub')
@click.pass_context
def package_stats(ctx, token, threads):
    """Get statistics for current rule package."""
    # Build the package quietly so only the stats below are printed.
    current_package: Package = ctx.invoke(build_release, verbose=False, release=None)
    release = 'v{}.0'.format(current_package.name)

    # Rules sitting in open PRs labeled for this release.
    new, modified, errors = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads)

    click.echo('Total rules as of {} package: {}'.format(release, len(current_package.rules)))
    click.echo('New rules: {}'.format(len(current_package.new_rules_ids)))
    click.echo('Modified rules: {}'.format(len(current_package.changed_rule_ids)))
    click.echo('Deprecated rules: {}'.format(len(current_package.removed_rule_ids)))

    click.echo('\n-----\n')
    click.echo('Rules in active PRs for current package: ')
    click.echo('New rules: {}'.format(len(new)))
    click.echo('Modified rules: {}'.format(len(modified)))
||||
@dev_group.command('search-rule-prs')
@click.argument('query', required=False)
@click.option('--no-loop', '-n', is_flag=True, help='Run once with no loop')
@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table')
@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql")
@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)')
@click.option('--threads', default=50, help='Number of threads to download rules from GitHub')
@click.pass_context
def search_rule_prs(ctx, no_loop, query, columns, language, token, threads):
    """Use KQL or EQL to find matching rules from active GitHub PRs."""
    from uuid import uuid4
    from .main import search_rules

    all_rules = {}
    new, modified, errors = rule_loader.load_github_pr_rules(token=token, threads=threads)

    def add_github_meta(this_rule, status, original_rule_id=None):
        """Attach PR metadata to a rule and register it in all_rules.

        Fix: operate on the `this_rule` parameter throughout. The previous body
        referenced the enclosing loop variable `rule`, which only matched
        `this_rule` by coincidence of the call sites below.
        """
        pr = this_rule.gh_pr
        this_rule.metadata['status'] = status
        this_rule.metadata['github'] = {
            'base': pr.base.label,
            'comments': [c.body for c in pr.get_comments()],
            'commits': pr.commits,
            'created_at': str(pr.created_at),
            'head': pr.head.label,
            'is_draft': pr.draft,
            'labels': [lbl.name for lbl in pr.get_labels()],
            'last_modified': str(pr.last_modified),
            'title': pr.title,
            'url': pr.html_url,
            'user': pr.user.login
        }

        if original_rule_id:
            # A modified rule keeps a pointer to the rule it changes and gets a
            # fresh id so it does not collide with the existing rule.
            this_rule.metadata['original_rule_id'] = original_rule_id
            this_rule.contents['rule_id'] = str(uuid4())

        rule_path = f'pr-{pr.number}-{this_rule.path}'
        all_rules[rule_path] = this_rule.rule_format()

    for rule_id, rule in new.items():
        add_github_meta(rule, 'new')

    for rule_id, rules in modified.items():
        for rule in rules:
            add_github_meta(rule, 'modified', rule_id)

    loop = not no_loop
    ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=loop)

    # Interactive search loop until the user interrupts (ctrl-z per the prompt).
    while loop:
        query = click.prompt(f'Search loop - enter new {language} query or ctrl-z to exit')
        columns = click.prompt('columns', default=','.join(columns)).split(',')
        ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=True)
||||
# Sub-group of the `dev` command group; holds commands that exercise live
# stack resources (as opposed to the offline packaging commands above).
@dev_group.group('test')
def test_group():
    """Commands for testing against stack resources."""
||||
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import click
|
||||
import jsonschema
|
||||
@@ -250,7 +251,7 @@ def validate_all(fail):
|
||||
@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table')
|
||||
@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql")
|
||||
@click.option('--count', is_flag=True, help='Return a count rather than table')
|
||||
def search_rules(query, columns, language, count, verbose=True):
|
||||
def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, dict] = None, pager=False):
|
||||
"""Use KQL or EQL to find matching rules."""
|
||||
from kql import get_evaluator
|
||||
from eql.table import Table
|
||||
@@ -259,8 +260,9 @@ def search_rules(query, columns, language, count, verbose=True):
|
||||
from eql.pipes import CountPipe
|
||||
|
||||
flattened_rules = []
|
||||
rules = rules or rule_loader.load_rule_files(verbose=verbose)
|
||||
|
||||
for file_name, rule_doc in rule_loader.load_rule_files(verbose=verbose).items():
|
||||
for file_name, rule_doc in rules.items():
|
||||
flat = {"file": os.path.relpath(file_name)}
|
||||
flat.update(rule_doc)
|
||||
flat.update(rule_doc["metadata"])
|
||||
@@ -309,7 +311,7 @@ def search_rules(query, columns, language, count, verbose=True):
|
||||
table = Table.from_list(columns, filtered)
|
||||
|
||||
if verbose:
|
||||
click.echo(table)
|
||||
click.echo_via_pager(table) if pager else click.echo(table)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
@@ -311,7 +311,8 @@ class Package(object):
|
||||
click.echo(f' - {len(all_rules) - len(rules)} rules excluded from package')
|
||||
|
||||
update = config.pop('update', {})
|
||||
package = cls(rules, deprecated_rules=deprecated_rules, update_version_lock=update_version_lock, **config)
|
||||
package = cls(rules, deprecated_rules=deprecated_rules, update_version_lock=update_version_lock,
|
||||
verbose=verbose, **config)
|
||||
|
||||
# Allow for some fields to be overwritten
|
||||
if update.get('data', {}):
|
||||
|
||||
@@ -9,6 +9,7 @@ import io
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List
|
||||
|
||||
import click
|
||||
import pytoml
|
||||
@@ -76,7 +77,7 @@ def load_rules(file_lookup=None, verbose=True, error=True):
|
||||
file_lookup = file_lookup or load_rule_files(verbose=verbose)
|
||||
|
||||
failed = False
|
||||
rules = [] # type: list[Rule]
|
||||
rules: List[Rule] = []
|
||||
errors = []
|
||||
queries = []
|
||||
query_check_index = []
|
||||
@@ -132,6 +133,65 @@ def load_rules(file_lookup=None, verbose=True, error=True):
|
||||
return OrderedDict([(rule.id, rule) for rule in sorted(rules, key=lambda r: r.name)])
|
||||
|
||||
|
||||
@cached
def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rules', token=None, threads=50,
                         verbose=True):
    """Load all rules active as a GitHub PR.

    :param labels: only consider PRs carrying ALL of these labels (None/empty = all open PRs)
    :param repo: GitHub repository in 'owner/name' form
    :param token: optional GitHub token for authenticated API access
    :param threads: size of the download thread pool
    :param verbose: echo progress to the terminal
    :return: (new, modified, errors) where new maps rule id -> Rule, modified maps
             rule id -> list[Rule] (one per PR touching it), and errors maps
             file name -> list of error strings
    """
    import requests
    import pytoml
    from multiprocessing.pool import ThreadPool
    from pathlib import Path
    from .misc import GithubClient

    github = GithubClient(token=token)
    repo = github.client.get_repo(repo)
    labels = set(labels or [])
    # Keep only PRs whose label set contains every requested label.
    open_prs = [pr for pr in repo.get_pulls() if labels.issubset({lbl.name for lbl in pr.get_labels()})]

    new_rules: List[Rule] = []
    modified_rules: List[Rule] = []
    errors: Dict[str, list] = {}

    existing_rules = load_rules(verbose=False)
    pr_rules = []

    if verbose:
        click.echo('Downloading rules from GitHub PRs')

    def download_worker(pr_info):
        # Fetch one rule file from a PR and bucket it as new vs. modified.
        pull, rule_file = pr_info
        try:
            # Fetch inside the try so network failures are collected per-file
            # instead of escaping the worker thread, and fail fast on HTTP
            # errors rather than feeding an error page to the TOML parser.
            response = requests.get(rule_file.raw_url)
            response.raise_for_status()
            raw_rule = pytoml.loads(response.text)
            rule = Rule(rule_file.filename, raw_rule)
            rule.gh_pr = pull

            if rule.id in existing_rules:
                modified_rules.append(rule)
            else:
                new_rules.append(rule)

        except Exception as e:
            errors.setdefault(Path(rule_file.filename).name, []).append(str(e))

    for pr in open_prs:
        pr_rules.extend([(pr, f) for f in pr.get_files()
                         if f.filename.startswith('rules/') and f.filename.endswith('.toml')])

    pool = ThreadPool(processes=threads)
    pool.map(download_worker, pr_rules)
    pool.close()
    pool.join()

    new = OrderedDict([(rule.id, rule) for rule in sorted(new_rules, key=lambda r: r.name)])
    modified = OrderedDict()

    # A single rule id can be modified by several PRs; group them per id.
    for modified_rule in sorted(modified_rules, key=lambda r: r.name):
        modified.setdefault(modified_rule.id, []).append(modified_rule)

    return new, modified, errors
||||
|
||||
@cached
|
||||
def get_rule(rule_id=None, rule_name=None, file_name=None, verbose=True):
|
||||
"""Get a rule based on its id."""
|
||||
@@ -195,6 +255,7 @@ __all__ = (
|
||||
"load_rule_files",
|
||||
"load_rules",
|
||||
"load_rule_files",
|
||||
"load_github_pr_rules",
|
||||
"get_file_name",
|
||||
"get_production_rules",
|
||||
"get_rule",
|
||||
|
||||
@@ -74,14 +74,15 @@ class TomlMetadata(GenericSchema):
|
||||
|
||||
# rule validated against each ecs schema contained
|
||||
beats_version = jsl.StringField(pattern=VERSION_PATTERN, required=False)
|
||||
comments = jsl.StringField(required=False)
|
||||
ecs_versions = jsl.ArrayField(jsl.StringField(pattern=VERSION_PATTERN, required=True), required=False)
|
||||
maturity = jsl.StringField(enum=MATURITY_LEVELS, default='development', required=True)
|
||||
|
||||
os_type_list = jsl.ArrayField(jsl.StringField(enum=OS_OPTIONS), required=False)
|
||||
query_schema_validation = jsl.BooleanField(required=False)
|
||||
related_endpoint_rules = jsl.ArrayField(jsl.ArrayField(jsl.StringField(), min_items=2, max_items=2),
|
||||
required=False)
|
||||
updated_date = jsl.StringField(required=True, pattern=DATE_PATTERN, default=time.strftime('%Y/%m/%d'))
|
||||
query_schema_validation = jsl.BooleanField(required=False)
|
||||
|
||||
|
||||
class BaseApiSchema(GenericSchema):
|
||||
|
||||
Reference in New Issue
Block a user