diff --git a/detection_rules/main.py b/detection_rules/main.py index 8e3abe1b4..06c83a888 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -186,7 +186,8 @@ def validate_all(fail): @click.argument('query', required=False) @click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') @click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -def search_rules(query, columns, language, verbose=True): +@click.option('--count', is_flag=True, help='Return a count rather than table') +def search_rules(query, columns, language, count, verbose=True): """Use KQL or EQL to find matching rules.""" from kql import get_evaluator from eql.table import Table @@ -201,10 +202,21 @@ def search_rules(query, columns, language, verbose=True): flat.update(rule_doc) flat.update(rule_doc["metadata"]) flat.update(rule_doc["rule"]) - attacks = [threat for threat in rule_doc["rule"].get("threat", []) if threat["framework"] == "MITRE ATT&CK"] - techniques = [t["id"] for threat in attacks for t in threat.get("technique", [])] - tactics = [threat["tactic"]["name"] for threat in attacks] - flat.update(techniques=techniques, tactics=tactics, + + tactic_names = [] + technique_ids = [] + subtechnique_ids = [] + + for entry in rule_doc['rule'].get('threat', []): + if entry["framework"] != "MITRE ATT&CK": + continue + + techniques = entry.get('technique', []) + tactic_names.append(entry['tactic']['name']) + technique_ids.extend([t['id'] for t in techniques]) + subtechnique_ids.extend([st['id'] for t in techniques for st in t.get('subtechnique', [])]) + + flat.update(techniques=technique_ids, tactics=tactic_names, subtechniques=subtechnique_ids, unique_fields=Rule.get_unique_query_fields(rule_doc['rule'])) flattened_rules.append(flat) @@ -222,6 +234,10 @@ def search_rules(query, columns, language, verbose=True): if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes): columns = ["key", "count", "percent"] + if count: + click.echo(f'{len(filtered)} rules') + return filtered + if columns: columns = ",".join(columns).split(",") else: