Add rule loader and dependencies

Co-Authored-By: Justin Ibarra <brokensound77@users.noreply.github.com>
Ross Wolf
2020-06-29 23:17:38 -06:00
parent a0d3b4bd23
commit 3b305d3003
40 changed files with 138653 additions and 0 deletions
+24
@@ -0,0 +1,24 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Detection rules."""
from . import eswrap
from . import main
from . import mappings
from . import misc
from . import rule_formatter
from . import rule_loader
from . import schema
from . import utils
__all__ = (
    'eswrap',
    'main',
    'mappings',
    'misc',
    'rule_formatter',
    'rule_loader',
    'schema',
    'utils',
)
+28
@@ -0,0 +1,28 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
# coding=utf-8
"""Shell for detection-rules."""
import os
from .main import root
CURR_DIR = os.path.dirname(os.path.abspath(__file__))
CLI_DIR = os.path.dirname(CURR_DIR)
ROOT_DIR = os.path.dirname(CLI_DIR)
BANNER = r"""
█▀▀▄ ▄▄▄ ▄▄▄ ▄▄▄ ▄▄▄ ▄▄▄ ▄▄▄ ▄▄▄ ▄ ▄ █▀▀▄ ▄ ▄ ▄ ▄▄▄ ▄▄▄
█ █ █▄▄ █ █▄▄ █ █ █ █ █ █▀▄ █ █▄▄▀ █ █ █ █▄▄ █▄▄
█▄▄▀ █▄▄ █ █▄▄ █▄▄ █ ▄█▄ █▄█ █ ▀▄█ █ ▀▄ █▄▄█ █▄▄ █▄▄ ▄▄█
"""
def main():
"""CLI entry point."""
print(BANNER)
root(prog_name="detection_rules")
if __name__ == '__main__':
    main()
+79
@@ -0,0 +1,79 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Mitre attack info."""
# from: https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json
from .utils import load_etc_dump
TACTICS_MAP = {
    'Initial Access': 'TA0001',
    'Execution': 'TA0002',
    'Persistence': 'TA0003',
    'Privilege Escalation': 'TA0004',
    'Defense Evasion': 'TA0005',
    'Credential Access': 'TA0006',
    'Discovery': 'TA0007',
    'Lateral Movement': 'TA0008',
    'Collection': 'TA0009',
    'Exfiltration': 'TA0010',
    'Command and Control': 'TA0011',
    'Impact': 'TA0040'
}
TACTICS = list(TACTICS_MAP)
PLATFORMS = ['Windows', 'macOS', 'Linux']
attack = load_etc_dump('attack.json')
technique_lookup = {}
for item in attack["objects"]:
if item["type"] == "attack-pattern" and item["external_references"][0]['source_name'] == 'mitre-attack':
technique_id = item['external_references'][0]['external_id']
technique_lookup[technique_id] = item
matrix = {tactic: [] for tactic in TACTICS}
attack_tm = 'ATT&CK\u2122'
# Enumerate over the techniques and build the matrix back up
for technique_id, technique in sorted(technique_lookup.items(), key=lambda kv: kv[1]['name'].lower()):
for platform in technique['x_mitre_platforms']:
if any(platform.startswith(p) for p in PLATFORMS):
break
else:
continue
for tactic in technique['kill_chain_phases']:
tactic_name = next(t for t in TACTICS if tactic['kill_chain_name'] == 'mitre-attack' and t.lower() == tactic['phase_name'].replace("-", " ")) # noqa: E501
matrix[tactic_name].append(technique_id)
for tactic in matrix:
matrix[tactic].sort(key=lambda tid: technique_lookup[tid]['name'].lower())
TECHNIQUES = {technique['name'] for technique in technique_lookup.values()}
def build_threat_map_entry(tactic: str, *technique_ids: str) -> dict:
"""Build rule threat map from technique IDs."""
url_base = 'https://attack.mitre.org/{type}/{id}/'
tactic_id = TACTICS_MAP[tactic]
entry = {
'framework': 'MITRE ATT&CK',
'technique': [
{
'id': tid,
'name': technique_lookup[tid]['name'],
'reference': url_base.format(type='techniques', id=tid)
} for tid in technique_ids
],
'tactic': {
'id': tactic_id,
'name': tactic,
'reference': url_base.format(type='tactics', id=tactic_id)
}
}
return entry
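# Illustrative usage (assuming technique T1078, Valid Accounts, is present in the loaded ATT&CK dump):
# build_threat_map_entry('Persistence', 'T1078') would return roughly:
# {
#     'framework': 'MITRE ATT&CK',
#     'tactic': {'id': 'TA0003', 'name': 'Persistence',
#                'reference': 'https://attack.mitre.org/tactics/TA0003/'},
#     'technique': [{'id': 'T1078', 'name': 'Valid Accounts',
#                    'reference': 'https://attack.mitre.org/techniques/T1078/'}]
# }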
+160
@@ -0,0 +1,160 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""ECS Schemas management."""
import os
import kql
import requests
import yaml
from .semver import Version
from .utils import unzip, load_etc_dump, save_etc_dump, get_etc_path
def download_latest_beats_schema():
"""Download additional schemas from ecs releases."""
url = 'https://api.github.com/repos/elastic/beats/releases'
releases = requests.get(url)
latest_release = max(releases.json(), key=lambda release: Version(release["tag_name"].lstrip("v")))
print(f"Downloading beats {latest_release['tag_name']}")
response = requests.get(latest_release['zipball_url'])
print(f"Downloaded {len(response.content) / 1024.0 / 1024.0:.2f} MB release.")
fs = {}
parsed = {}
with unzip(response.content) as archive:
base_directory = archive.namelist()[0]
for name in archive.namelist():
if os.path.basename(name) in ("fields.yml", "fields.common.yml", "config.yml"):
contents = archive.read(name)
# chop off the base directory name
key = name[len(base_directory):]
if key.startswith("x-pack"):
key = key[len("x-pack") + 1:]
try:
    decoded = yaml.safe_load(contents)
except yaml.YAMLError:
    print(f"Error loading {name}")
    # skip files that fail to parse rather than reusing a stale value
    continue
# create a hierarchical structure
parsed[key] = decoded
branch = fs
directory, base_name = os.path.split(key)
for limb in directory.split(os.path.sep):
branch = branch.setdefault("folders", {}).setdefault(limb, {})
branch.setdefault("files", {})[base_name] = decoded
# remove all non-beat directories
fs = {k: v for k, v in fs.get("folders", {}).items() if k.endswith("beat")}
print(f"Saving etc/beats_schema/{latest_release['tag_name']}.yml")
save_etc_dump(fs, "beats_schemas", latest_release["tag_name"] + ".yml")
def _flatten_schema(schema: list, prefix="") -> list:
if schema is None:
# sometimes we see `fields: null` in the yaml
return []
flattened = []
for s in schema:
if s.get("type") == "group":
flattened.extend(_flatten_schema(s["fields"], prefix=prefix + s["name"] + "."))
elif "fields" in s:
flattened.extend(_flatten_schema(s["fields"], prefix=prefix))
elif "type" in s:
s = s.copy()
s["name"] = prefix + s["name"]
flattened.append(s)
return flattened
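# Illustrative example: a "group" entry is expanded using its name as a dotted prefix, e.g.
# _flatten_schema([{'name': 'process', 'type': 'group',
#                   'fields': [{'name': 'name', 'type': 'keyword'}]}])
# returns [{'name': 'process.name', 'type': 'keyword'}]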
def get_field_schema(base_directory, prefix="", include_common=False):
base_directory = base_directory.get("folders", {}).get("_meta", {}).get("files", {})
flattened = []
file_names = ("fields.yml", "fields.common.yml") if include_common else ("fields.yml", )
for name in file_names:
if name in base_directory:
flattened.extend(_flatten_schema(base_directory[name], prefix=prefix))
return flattened
def get_beats_schema(schema: dict, beat: str, module: str, *datasets: str):
if beat not in schema:
raise KeyError(f"Unknown beats module {beat}")
flattened = []
beat_dir = schema[beat]
flattened.extend(get_field_schema(beat_dir, include_common=True))
module_dir = beat_dir.get("folders", {}).get("module", {}).get("folders", {}).get(module, {})
flattened.extend(get_field_schema(module_dir, include_common=True))
# if we only have a module then we'll work with what we got
if not datasets:
datasets = [d for d in module_dir.get("folders", {}) if not d.startswith("_")]
for dataset in datasets:
# replace aws.s3 -> s3
if dataset.startswith(module + "."):
dataset = dataset[len(module) + 1:]
dataset_dir = module_dir.get("folders", {}).get(dataset, {})
flattened.extend(get_field_schema(dataset_dir, prefix=module + ".", include_common=True))
return {field["name"]: field for field in sorted(flattened, key=lambda f: f["name"])}
SCHEMA = None
def read_beats_schema():
global SCHEMA
if SCHEMA is None:
beats_schemas = os.listdir(get_etc_path("beats_schemas"))
latest = max(beats_schemas, key=lambda b: Version(b.lstrip("v")))
SCHEMA = load_etc_dump("beats_schemas", latest)
return SCHEMA
def get_schema_for_query(tree: kql.ast, beats: list) -> dict:
filtered = {}
modules = set()
datasets = set()
# extract out event.module and event.dataset from the query's AST
for node in tree:
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"):
modules.update(child.value for child in node.value if isinstance(child, kql.ast.String))
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"):
datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String))
beats_schema = read_beats_schema()
for beat in beats:
# if no modules are specified then grab them all
# all_modules = list(beats_schema.get(beat, {}).get("folders", {}).get("module", {}).get("folders", {}))
# beat_modules = modules or all_modules
for module in modules:
filtered.update(get_beats_schema(beats_schema, beat, module, *datasets))
return filtered
+233
@@ -0,0 +1,233 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""ECS Schemas management."""
import copy
import glob
import os
import shutil
import requests
import yaml
from .semver import Version
from .utils import unzip, load_etc_dump, get_etc_path, cached
ECS_SCHEMAS_DIR = get_etc_path("ecs_schemas")
def add_field(schema, name, info):
    """Nest a dotted field within a dictionary."""
    if "." not in name:
        schema[name] = info
        return
    top, remaining = name.split(".", 1)
    if not isinstance(schema.get(top), dict):
        schema[top] = {}
    # recurse into the nested dictionary rather than the top-level schema
    add_field(schema[top], remaining, info)
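# Illustrative example:
# schema = {}
# add_field(schema, 'process.name', 'string')
# schema is now {'process': {'name': 'string'}}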
def nest_from_dot(dots, value):
"""Nest a dotted field and set the inner most value."""
fields = dots.split('.')
if not fields:
return {}
nested = {fields.pop(): value}
for field in reversed(fields):
nested = {field: nested}
return nested
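# Illustrative example:
# nest_from_dot('agent.type', 'filebeat') returns {'agent': {'type': 'filebeat'}}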
def _recursive_merge(existing, new, depth=0):
"""Return an existing dict merged into a new one."""
for key, value in existing.items():
if isinstance(value, dict):
if depth == 0:
new = copy.deepcopy(new)
node = new.setdefault(key, {})
_recursive_merge(value, node, depth + 1)
else:
new[key] = value
return new
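# Illustrative example: values from `existing` win at the leaves, e.g.
# _recursive_merge({'a': {'b': 1}}, {'a': {'b': 0, 'c': 2}}) returns {'a': {'b': 1, 'c': 2}}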
def get_schema_files():
"""Get schema files from ecs directory."""
return glob.glob(os.path.join(ECS_SCHEMAS_DIR, '*', '*.yml'), recursive=True)
def get_schema_map():
"""Get local schema files by version."""
schema_map = {}
for file_name in get_schema_files():
path, name = os.path.split(file_name)
name = os.path.splitext(name)[0]
version = os.path.basename(path)
schema_map.setdefault(version, {})[name] = file_name
return schema_map
@cached
def get_schemas():
"""Get local schemas."""
schema_map = get_schema_map()
for version, values in schema_map.items():
for name, file_name in values.items():
with open(file_name, 'r') as f:
schema_map[version][name] = yaml.safe_load(f)
return schema_map
def get_max_version(include_master=False):
"""Get maximum available schema version."""
versions = get_schema_map().keys()
if include_master and any(v.startswith('master') for v in versions):
return glob.glob(os.path.join(ECS_SCHEMAS_DIR, 'master*'))[0]
return str(max([Version(v) for v in versions if not v.startswith('master')]))
@cached
def get_schema(version=None, name='ecs_flat'):
"""Get schema by version."""
return get_schemas()[version][name]
@cached
def get_eql_schema(version=None, index_patterns=None):
"""Return schema in expected format for eql."""
schema = get_schema(version, name='ecs_flat')
str_types = ('text', 'ip', 'keyword', 'date', 'object', 'geo_point')
num_types = ('float', 'integer', 'long')
schema = schema.copy()
def convert_type(t):
return 'string' if t in str_types else 'number' if t in num_types else 'boolean'
converted = {}
for field, schema_info in schema.items():
field_type = schema_info.get('type', '')
add_field(converted, field, convert_type(field_type))
if index_patterns:
for index_name in index_patterns:
for k, v in flatten(get_index_schema(index_name)).items():
add_field(converted, k, convert_type(v))
return converted
def flatten(schema):
flattened = {}
for k, v in schema.items():
if isinstance(v, dict):
flattened.update((k + "." + vk, vv) for vk, vv in flatten(v).items())
else:
flattened[k] = v
return flattened
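# Illustrative example:
# flatten({'process': {'name': 'string', 'pid': 'number'}})
# returns {'process.name': 'string', 'process.pid': 'number'}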
@cached
def get_non_ecs_schema():
"""Load non-ecs schema."""
return load_etc_dump('non-ecs-schema.json')
@cached
def get_index_schema(index_name):
return get_non_ecs_schema().get(index_name, {})
def flatten_multi_fields(schema):
converted = {}
for field, info in schema.items():
converted[field] = info["type"]
for subfield in info.get("multi_fields", []):
converted[field + "." + subfield["name"]] = subfield["type"]
return converted
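# Illustrative example with a hypothetical ecs_flat-style entry:
# flatten_multi_fields({'host.os.name': {'type': 'keyword',
#                                        'multi_fields': [{'name': 'text', 'type': 'text'}]}})
# returns {'host.os.name': 'keyword', 'host.os.name.text': 'text'}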
@cached
def get_kql_schema(version=None, indexes=None, beat_schema=None):
"""Get schema for KQL."""
indexes = indexes or ()
converted = flatten_multi_fields(get_schema(version, name='ecs_flat'))
for index_name in indexes:
converted.update(**flatten(get_index_schema(index_name)))
if isinstance(beat_schema, dict):
converted = dict(flatten_multi_fields(beat_schema), **converted)
return converted
def download_schemas(refresh_master=True, refresh_all=False, verbose=True):
"""Download additional schemas from ecs releases."""
existing = [Version(v) for v in get_schema_map()] if not refresh_all else []
url = 'https://api.github.com/repos/elastic/ecs/releases'
releases = requests.get(url)
for release in releases.json():
version = Version(release.get('tag_name', '').lstrip('v'))
# we don't ever want beta
if not version or version < (1, 0, 1) or version in existing:
continue
schema_dir = os.path.join(ECS_SCHEMAS_DIR, str(version))
with unzip(requests.get(release['zipball_url']).content) as archive:
name_list = archive.namelist()
base = name_list[0]
# members = [m for m in name_list if m.startswith('{}{}/'.format(base, 'use-cases')) and m.endswith('.yml')]
members = ['{}generated/ecs/ecs_flat.yml'.format(base), '{}generated/ecs/ecs_nested.yml'.format(base)]
for member in members:
file_name = os.path.basename(member)
os.makedirs(schema_dir, exist_ok=True)
with open(os.path.join(schema_dir, file_name), 'wb') as f:
f.write(archive.read(member))
if verbose:
print('Saved files to {}: \n\t- {}'.format(schema_dir, '\n\t- '.join(members)))
# handle working master separately
if refresh_master:
master_ver = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/version')
master_ver = Version(master_ver.text.strip())
master_schema = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/generated/ecs/ecs_flat.yml')
master_schema = yaml.safe_load(master_schema.text)
# prepend with underscore so that we can differentiate the fact that this is a working master version
# but first clear out any existing masters, since we only ever want 1 at a time
existing_master = glob.glob(os.path.join(ECS_SCHEMAS_DIR, 'master_*'))
for m in existing_master:
shutil.rmtree(m, ignore_errors=True)
master_dir = os.path.join(ECS_SCHEMAS_DIR, 'master_{}'.format(master_ver))
master_file = os.path.join(master_dir, 'ecs_flat.yml')
os.makedirs(master_dir, exist_ok=True)
with open(master_file, 'w') as f:
yaml.safe_dump(master_schema, f)
if verbose:
print('Saved files to {}: \n\t- {}'.format(master_dir, 'ecs_flat.yml'))
+222
@@ -0,0 +1,222 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Elasticsearch cli and tmp."""
import json
import os
import time
import click
from elasticsearch import AuthenticationException, Elasticsearch
from .main import root
from .misc import set_param_values
from .utils import normalize_timing_and_sort, unix_time_to_formatted, get_path
from .rule_loader import get_rule, rta_mappings
COLLECTION_DIR = get_path('collections')
ERRORS = {
'NO_EVENTS': 1,
'FAILED_ES_AUTH': 2
}
@root.group('es')
def es_group():
"""Helper commands for integrating with Elasticsearch."""
def get_es_client(user, password, host=None, cloud_id=None, **kwargs):
"""Get an auth-validated elsticsearch client."""
assert host or cloud_id, 'You must specify a host or cloud-id to authenticate to elasticsearch instance'
hosts = [host] if host else host
client = Elasticsearch(hosts=hosts, cloud_id=cloud_id, http_auth=(user, password), **kwargs)
# force login to test auth
client.info()
return client
class Events(object):
"""Events collected from Elasticsearch."""
def __init__(self, agent_hostname, events):
self.agent_hostname = agent_hostname
self.events = self._normalize_event_timing(events)
@staticmethod
def _normalize_event_timing(events):
"""Normalize event timestamps and sort."""
for agent_type, _events in events.items():
events[agent_type] = normalize_timing_and_sort(_events)
return events
def _get_dump_dir(self, rta_name=None):
"""Prepare and get the dump path."""
if rta_name:
dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name)
os.makedirs(dump_dir, exist_ok=True)
return dump_dir
else:
time_str = time.strftime('%Y%m%dT%H%M%SL')
dump_dir = os.path.join(COLLECTION_DIR, self.agent_hostname, time_str)
os.makedirs(dump_dir, exist_ok=True)
return dump_dir
def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=True):
"""Evaluate a rule against collected events and update mapping."""
from .utils import combine_sources, evaluate
rule = get_rule(rule_id, verbose=False)
merged_events = combine_sources(*self.events.values())
filtered = evaluate(rule, merged_events)
if filtered:
sources = [e['agent']['type'] for e in filtered]
mapping_update = rta_mappings.add_rule_to_mapping_file(rule, len(filtered), rta_name, *sources)
if verbose:
click.echo('Updated rule-mapping file with: \n{}'.format(json.dumps(mapping_update, indent=2)))
else:
if verbose:
click.echo('No updates to rule-mapping file; No matching results')
def echo_events(self, pager=False, pretty=True):
"""Print events to stdout."""
echo_fn = click.echo_via_pager if pager else click.echo
echo_fn(json.dumps(self.events, indent=2 if pretty else None, sort_keys=True))
def save(self, rta_name=None, dump_dir=None):
"""Save collected events."""
assert self.events, 'Nothing to save. Run Collector.run() method first'
dump_dir = dump_dir or self._get_dump_dir(rta_name)
for source, events in self.events.items():
path = os.path.join(dump_dir, source + '.jsonl')
with open(path, 'w') as f:
f.writelines([json.dumps(e, sort_keys=True) + '\n' for e in events])
click.echo('{} events saved to: {}'.format(len(events), path))
class CollectEvents(object):
"""Event collector for elastic stack."""
def __init__(self, client, max_events=3000):
self.client = client
self.MAX_EVENTS = max_events
def _build_timestamp_map(self, index_str):
"""Build a mapping of indexes to timestamp data formats."""
mappings = self.client.indices.get_mapping(index=index_str)
timestamp_map = {n: m['mappings'].get('properties', {}).get('@timestamp', {}) for n, m in mappings.items()}
return timestamp_map
def _get_current_time(self, agent_hostname, index_str):
"""Get timestamp of most recent event."""
# https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-date-format.html
timestamp_map = self._build_timestamp_map(index_str)
last_event = self._search_window(agent_hostname, index_str, start_time='now-1m', size=1, sort='@timestamp:desc')
last_event = last_event['hits']['hits'][0]
index = last_event['_index']
timestamp = last_event['_source']['@timestamp']
event_date_format = timestamp_map[index].get('format', '').split('||')
# there are many natively supported date formats and even custom date formats, but most, including beats,
# use the default `strict_date_optional_time`. It would be difficult to account for all possible formats,
# so this will work on the default and unix time.
if set(event_date_format) & {'epoch_millis', 'epoch_second'}:
timestamp = unix_time_to_formatted(timestamp)
return timestamp
def _search_window(self, agent_hostname, index_str, start_time, end_time='now', size=None, sort='@timestamp:asc',
**match):
"""Collect all events within a time window and parse by source."""
match = match.copy()
match.update({"agent.hostname": agent_hostname})
body = {"query": {"bool": {"filter": [
{"match": {"agent.hostname": agent_hostname}},
{"range": {"@timestamp": {"gt": start_time, "lte": end_time, "format": "strict_date_optional_time"}}}]
}}}
if match:
body['query']['bool']['filter'].extend([{'match': {k: v}} for k, v in match.items()])
return self.client.search(index=index_str, body=body, size=size or self.MAX_EVENTS, sort=sort)
@staticmethod
def _group_events_by_type(events):
"""Group events by agent.type."""
event_by_type = {}
for event in events['hits']['hits']:
event_by_type.setdefault(event['_source']['agent']['type'], []).append(event['_source'])
return event_by_type
def run(self, agent_hostname, indexes, verbose=True, **match):
"""Collect the events."""
index_str = ','.join(indexes)
start_time = self._get_current_time(agent_hostname, index_str)
if verbose:
click.echo('Setting start of event capture to: {}'.format(click.style(start_time, fg='yellow')))
click.pause('Press any key once detonation is complete ...')
time.sleep(5)
events = self._group_events_by_type(self._search_window(agent_hostname, index_str, start_time, **match))
return Events(agent_hostname, events)
@es_group.command('collect-events')
@click.argument('agent-hostname')
@click.option('--host', callback=set_param_values, expose_value=True)
@click.option('--cloud-id', callback=set_param_values, expose_value=True)
@click.option('--user', '-u', callback=set_param_values, expose_value=True, hide_input=False)
@click.option('--password', '-p', callback=set_param_values, expose_value=True, hide_input=True)
@click.option('--index', '-i', multiple=True, help='Index(es) to search against (default: all indexes)')
@click.option('--agent-type', '-a', help='Restrict results to a source type (agent.type) ex: auditbeat')
@click.option('--rta-name', '-r', help='Name of RTA in order to save events directly to unit tests data directory')
@click.option('--rule-id', help='Updates rule mapping in rule-mapping.yml file (requires --rta-name)')
@click.option('--view-events', is_flag=True, help='Print events after saving')
def collect_events(agent_hostname, host, cloud_id, user, password, index, agent_type, rta_name, rule_id, view_events):
"""Collect events from Elasticsearch."""
match = {'agent.type': agent_type} if agent_type else {}
try:
client = get_es_client(host=host, use_ssl=True, cloud_id=cloud_id, user=user, password=password)
except AuthenticationException:
click.secho('Failed authentication for {}'.format(host or cloud_id), fg='red', err=True)
return ERRORS['FAILED_ES_AUTH']
try:
collector = CollectEvents(client)
events = collector.run(agent_hostname, index, **match)
events.save(rta_name)
except AssertionError:
click.secho('No events collected! Verify events are streaming and that the agent-hostname is correct',
err=True, fg='red')
return ERRORS['NO_EVENTS']
if rta_name and rule_id:
events.evaluate_against_rule_and_update_mapping(rule_id, rta_name)
if view_events and events.events:
events.echo_events(pager=True)
return events
@es_group.command('normalize-data')
@click.argument('events-file', type=click.File('r'))
def normalize_file(events_file):
"""Normalize Elasticsearch data timestamps and sort."""
file_name = os.path.splitext(os.path.basename(events_file.name))[0]
events = Events('_', {file_name: [json.loads(e) for e in events_file.readlines()]})
events.save(dump_dir=os.path.dirname(events_file.name))
+352
@@ -0,0 +1,352 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""CLI commands for detection_rules."""
import glob
import io
import json
import os
import click
import jsonschema
import pytoml
from eql import load_dump
from .misc import nested_set
from . import rule_loader
from .packaging import PACKAGE_FILE, Package, manage_versions
from .rule import RULE_TYPE_OPTIONS, Rule
from .rule_formatter import toml_write
from .utils import get_path, clear_caches
RULES_DIR = get_path('rules')
@click.group('detection-rules', context_settings={'help_option_names': ['-h', '--help']})
def root():
"""Commands for detection-rules repository."""
@root.command('create-rule')
@click.argument('path', type=click.Path(dir_okay=False))
@click.option('--config', '-c', type=click.Path(exists=True, dir_okay=False), help='Rule or config file')
@click.option('--required-only', is_flag=True, help='Only prompt for required fields')
@click.option('--rule-type', '-t', type=click.Choice(RULE_TYPE_OPTIONS), help='Type of rule to create')
def create_rule(path, config, required_only, rule_type):
"""Create a detection rule."""
config = load_dump(config) if config else {}
try:
return Rule.build(path, rule_type=rule_type, required_only=required_only, save=True, **config)
finally:
rule_loader.reset()
@root.command('load-from-file')
@click.argument('infile', type=click.Path(dir_okay=False, exists=True), nargs=-1, required=False)
@click.option('--directory', '-d', type=click.Path(file_okay=False, exists=True), help='Load files from a directory')
def load_from_file(infile, directory):
"""Load rules from file(s)."""
if infile:
for rule_file in infile:
rule_path = os.path.join(RULES_DIR, os.path.basename(rule_file))
rule = Rule(rule_path, load_dump(rule_file))
rule.save(as_rule=True, verbose=True)
elif directory:
for rule_file in glob.glob(os.path.join(directory, '**', '*.*'), recursive=True):
try:
rule_path = os.path.join(RULES_DIR, os.path.basename(rule_file))
rule = Rule(rule_path, load_dump(rule_file))
rule.save(as_rule=True, verbose=True)
except ValueError:
click.echo('Unable to load file: {}'.format(rule_file))
else:
click.echo('No files specified!')
@root.command('toml-lint')
@click.option('--rule-file', '-f', type=click.File('r'), help='Optionally specify a specific rule file only')
def toml_lint(rule_file):
"""Cleanup files with some simple toml formatting."""
if rule_file:
contents = pytoml.load(rule_file)
rule = Rule(path=rule_file.name, contents=contents)
# remove unneeded defaults
for field in rule_loader.find_unneeded_defaults(rule):
rule.contents.pop(field, None)
rule.save(as_rule=True)
else:
for rule in rule_loader.load_rules().values():
# remove unneeded defaults
for field in rule_loader.find_unneeded_defaults(rule):
rule.contents.pop(field, None)
rule.save(as_rule=True)
rule_loader.reset()
click.echo('Toml file linting complete')
@root.command('mass-update')
@click.argument('query')
@click.option('--field', type=(str, str), multiple=True,
help='Use rule-search to retrieve a subset of rules and modify values '
'(ex: --field management.ecs_version 1.1.1).\n'
'Note this is limited to string fields only. Nested fields should use dot notation.')
@click.pass_context
def mass_update(ctx, query, field):
"""Update multiple rules based on eql results."""
results = ctx.invoke(search_rules, query=query, verbose=False)
rules = [rule_loader.get_rule(r['rule_id']) for r in results]
for rule in rules:
for key, value in field:
nested_set(rule.contents, key, value)
rule.validate(as_rule=True)
rule.save()
ctx.invoke(search_rules, query=query, columns=[k[0].split('.')[-1] for k in field])
return
@root.command('view-rule')
@click.argument('rule-id', required=False)
@click.option('--rule-file', '-f', type=click.Path(dir_okay=False), help='Optionally view a rule from a specified file')
@click.option('--as-api/--as-rule', default=True, help='Print the rule in final api or rule format')
@click.option('--optimize/--no-optimize', default=False, help='When viewing in api format, include optimizations')
def view_rule(rule_id, rule_file, as_api, optimize):
"""View an internal rule or specified rule file."""
if rule_id:
rule = rule_loader.get_rule(rule_id, verbose=False)
elif rule_file:
rule = Rule(rule_file, load_dump(rule_file))
else:
click.secho('Unknown rule!', fg='red')
return
if not rule:
click.secho('Unknown format!', fg='red')
return
if optimize and as_api:
rule.tune()
click.echo(toml_write(rule.rule_format()) if not as_api else json.dumps(rule.contents, indent=2, sort_keys=True))
return rule
@root.command('validate-rule')
@click.argument('rule-id', required=False)
@click.option('--rule-name', '-n')
@click.option('--path', '-p', type=click.Path(dir_okay=False))
def validate_rule(rule_id, rule_name, path):
"""Check if a rule staged in rules dir validates against a schema."""
rule = rule_loader.get_rule(rule_id, rule_name, path, verbose=False)
if not rule:
return click.secho('Rule not found!', fg='red')
try:
    rule.validate(as_rule=True)
except jsonschema.ValidationError as e:
    click.echo(e)
    return
click.echo('Rule validation successful')
return rule
license_header = """
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
""".strip()
@root.command('license-check')
@click.pass_context
def license_check(ctx):
"""Check that all code files contain a valid license."""
failed = False
for path in glob.glob(get_path("**", "*.py"), recursive=True):
if path.startswith(get_path("env", "")):
continue
relative_path = os.path.relpath(path)
with io.open(path, "rt", encoding="utf-8") as f:
contents = f.read()
# skip over shebang lines
if contents.startswith("#!/"):
_, _, contents = contents.partition("\n")
if not contents.lstrip("\r\n").startswith(license_header):
if not failed:
click.echo("Missing license headers for:", err=True)
failed = True
click.echo(relative_path, err=True)
ctx.exit(int(failed))
@root.command('validate-all')
@click.option('--fail/--no-fail', default=True, help='Fail on first failure or process through all printing errors.')
def validate_all(fail):
"""Check if all rules validates against a schema."""
rule_loader.load_rules(verbose=True, error=fail)
click.echo('Rule validation successful')
@root.command('rule-search')
@click.argument('query', required=False)
@click.option('--columns', '-c', multiple=True, help='Specify columns to add to the table')
@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql")
def search_rules(query, columns, language, verbose=True):
"""Use KQL to find matching rules."""
from kql import get_evaluator
from eql.table import Table
from eql.build import get_engine
from eql import parse_query
from eql.pipes import CountPipe
flattened_rules = []
for file_name, rule_doc in rule_loader.load_rule_files().items():
flat = {"file": os.path.relpath(file_name)}
flat.update(rule_doc)
flat.update(rule_doc["metadata"])
flat.update(rule_doc["rule"])
attacks = [threat for threat in rule_doc["rule"].get("threat", []) if threat["framework"] == "MITRE ATT&CK"]
techniques = [t["id"] for threat in attacks for t in threat.get("technique", [])]
tactics = [threat["tactic"]["name"] for threat in attacks]
flat.update(techniques=techniques, tactics=tactics)
flattened_rules.append(flat)
flattened_rules.sort(key=lambda dct: dct["name"])
if language == "kql":
evaluator = get_evaluator(query) if query else lambda x: True
filtered = list(filter(evaluator, flattened_rules))
elif language == "eql":
parsed = parse_query(query, implied_any=True, implied_base=True)
evaluator = get_engine(parsed)
filtered = [result.events[0].data for result in evaluator(flattened_rules)]
if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes):
columns = ["key", "count", "percent"]
if columns:
columns = ",".join(columns).split(",")
else:
columns = ["rule_id", "file", "name"]
table = Table.from_list(columns, filtered)
if verbose:
click.echo(table)
return filtered
@root.command('build-release')
@click.argument('config-file', type=click.Path(exists=True, dir_okay=False), required=False, default=PACKAGE_FILE)
@click.option('--update-version-lock', '-u', is_flag=True,
help='Save version.lock.json file with updated rule versions in the package')
def build_release(config_file, update_version_lock):
"""Assemble all the rules into Kibana-ready release files."""
config = load_dump(config_file)['package']
click.echo('[+] Building package {}'.format(config.get('name')))
package = Package.from_config(config, update_version_lock=update_version_lock)
package.save()
package.get_package_hash(verbose=True)
click.echo('- {} rules included'.format(len(package.rules)))
@root.command('update-lock-versions')
@click.argument('rule-ids', nargs=-1, required=True)
def update_lock_versions(rule_ids):
"""Update rule hashes in version.lock.json file without bumping version."""
from .packaging import manage_versions
if not click.confirm('Are you sure you want to update hashes without a version bump?'):
return
rules = [r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_ids]
changed, new = manage_versions(rules, exclude_version_update=True, add_new=False, save_changes=True)
if not changed:
click.echo('No hashes updated')
return changed
@root.command('kibana-diff')
@click.option('--rule-id', '-r', multiple=True, help='Optionally specify rule ID')
@click.option('--branch', '-b', default='master', help='Specify the kibana branch to diff against')
def kibana_diff(rule_id, branch):
"""Diff rules against their version represented in kibana if exists."""
from .misc import get_kibana_rules
if rule_id:
rules = [r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_id]
else:
rules = [r for r in rule_loader.load_rules(verbose=False).values() if r.metadata['maturity'] == 'production']
# add versions to the rules
manage_versions(rules, verbose=False)
rule_paths = [os.path.basename(r.path) for r in rules]
try:
original_gh_rules = get_kibana_rules(*rule_paths, branch=branch).values()
except ValueError as e:
click.secho(e.args[0], fg='red', err=True)
return
gh_rule_versions = {r['rule_id']: r.pop('version') for r in original_gh_rules}
rule_versions = {r.id: r.contents.pop('version') for r in rules}
gh_rules = {r['rule_id']: Rule('_', r) for r in original_gh_rules}
rule_ids = [r.id for r in rules]
gh_rule_ids = [r.id for r in gh_rules.values()]
missing_rules = [r for r in gh_rules.values() if r.id in list(set(gh_rule_ids).difference(set(rule_ids)))]
diff = {
'missing_from_kibana': [],
'diff': [],
'missing_from_rules': ['{} - {}'.format(r.id, r.name) for r in missing_rules]
}
for rule in rules:
if rule.id not in gh_rule_ids:
diff['missing_from_kibana'].append('{} - {}'.format(rule.id, rule.name))
continue
gh_rule = gh_rules[rule.id]
if rule.get_hash() != gh_rule.get_hash():
diff['diff'].append('versions - repo: {}, kibana: {} -> {} - {}'.format(
rule_versions[rule.id], gh_rule_versions[rule.id], rule.id, rule.name))
click.echo(json.dumps(diff, indent=2, sort_keys=True))
@root.command("test")
@click.pass_context
def test_rules(ctx):
"""Run unit tests over all of the rules."""
import pytest
clear_caches()
ctx.exit(pytest.main(["-v"]))
+75
@@ -0,0 +1,75 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""RTA to rule mappings."""
import os
from collections import defaultdict
from .schema import validate_rta_mapping
from .utils import load_etc_dump, save_etc_dump, get_path
RTA_DIR = get_path("rta")
class RtaMappings(object):
"""Rta-mapping helper class."""
def __init__(self):
"""Rta-mapping validation and prep."""
self.mapping = load_etc_dump('rule-mapping.yml') # type: dict
self.validate()
self._rta_mapping = defaultdict(list)
self._remote_rta_mapping = {}
self._rule_mappings = {}
def validate(self):
"""Validate mapping against schema."""
for v in self.mapping.values():
    validate_rta_mapping(v)
def add_rule_to_mapping_file(self, rule, count, rta_name, *sources):
"""Insert a rule mapping into the mapping file."""
mapping = self.mapping
rule_map = {
'count': count,
'rta_name': rta_name,
'rule_name': rule.name,
}
if sources:
rule_map['sources'] = list(sources)
mapping[rule.id] = rule_map
self.mapping = dict(sorted(mapping.items()))
save_etc_dump(self.mapping, 'rule-mapping.yml')
return rule_map
def get_rta_mapping(self):
    """Build the rule<-->rta mapping based off the mapping file."""
    if not self._rta_mapping:
        # entries are keyed by rule id and store a single 'rta_name' (see add_rule_to_mapping_file)
        self._rta_mapping = {rule_id: data['rta_name'] for rule_id, data in self.mapping.items()}
    return self._rta_mapping
def get_rta_files(self, rta_list=None, rule_ids=None):
"""Get the full paths to RTA files, given a list of names or rule ids."""
full_rta_mapping = self.get_rta_mapping()
rta_files = set()
rta_list = set(rta_list or [])
if rule_ids:
for rule_id, rta_map in full_rta_mapping.items():
if rule_id in rule_ids:
rta_list.add(rta_map)
for rta_name in rta_list:
# rip off the extension and add .py
rta_name, _ = os.path.splitext(os.path.basename(rta_name))
rta_path = os.path.abspath(os.path.join(RTA_DIR, rta_name + ".py"))
if os.path.exists(rta_path):
rta_files.add(rta_path)
return list(sorted(rta_files))
+197
@@ -0,0 +1,197 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Misc support."""
import json
import os
import re
import time
import uuid
import click
import requests
from .utils import ROOT_DIR
_CONFIG = {}
def nested_get(_dict, dot_key, default=None):
"""Get a nested field from a nested dict with dot notation."""
if _dict is None or dot_key is None:
return default
elif '.' in dot_key and isinstance(_dict, dict):
dot_key = dot_key.split('.')
this_key = dot_key.pop(0)
return nested_get(_dict.get(this_key, default), '.'.join(dot_key), default)
else:
return _dict.get(dot_key, default)
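# Illustrative example:
# nested_get({'rule': {'name': 'Test Rule'}}, 'rule.name') returns 'Test Rule'
# nested_get({'rule': {}}, 'rule.name', default='unknown') returns 'unknown'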
def nested_set(_dict, dot_key, value):
    """Set a nested field from a key in dot notation."""
    keys = dot_key.split('.')
    for key in keys[:-1]:
        _dict = _dict.setdefault(key, {})
    if isinstance(_dict, dict):
        # index with the final dotted segment, not the last character of the key
        _dict[keys[-1]] = value
    else:
        raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))
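# Illustrative example:
# contents = {}
# nested_set(contents, 'metadata.ecs_version', '1.5.0')
# contents is now {'metadata': {'ecs_version': '1.5.0'}}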
def schema_prompt(name, value=None, required=False, **options):
"""Interactively prompt based on schema requirements."""
name = str(name)
field_type = options.get('type')
pattern = options.get('pattern')
enum = options.get('enum', [])
minimum = options.get('minimum')
maximum = options.get('maximum')
min_items = options.get('min_items', 0)
max_items = options.get('max_items', 9999)
default = options.get('default')
if default is not None and str(default).lower() in ('true', 'false'):
default = str(default).lower()
if 'date' in name:
default = time.strftime('%Y/%m/%d')
if name == 'rule_id':
default = str(uuid.uuid4())
def _check_type(_val):
if field_type in ('number', 'integer') and not str(_val).isdigit():
print('Number expected but got: {}'.format(_val))
return False
if pattern and (not re.match(pattern, _val) or len(re.match(pattern, _val).group(0)) != len(_val)):
print('{} did not match pattern: {}!'.format(_val, pattern))
return False
if enum and _val not in enum:
print('{} not in valid options: {}'.format(_val, ', '.join(enum)))
return False
if minimum is not None and str(_val).isdigit() and int(_val) < minimum:
    print('{} is less than the minimum: {}'.format(str(_val), str(minimum)))
    return False
if maximum is not None and str(_val).isdigit() and int(_val) > maximum:
    print('{} is greater than the maximum: {}'.format(str(_val), str(maximum)))
    return False
if field_type == 'boolean' and _val.lower() not in ('true', 'false'):
print('Boolean expected but got: {}'.format(str(_val)))
return False
return True
def _convert_type(_val):
if field_type == 'boolean' and not type(_val) == bool:
_val = True if _val.lower() == 'true' else False
return int(_val) if field_type in ('number', 'integer') else _val
prompt = '{name}{default}{required}{multi}'.format(
name=name,
default=' [{}] ("n/a" to leave blank) '.format(default) if default else '',
required=' (required) ' if required else '',
multi=' (multi, comma separated) ' if field_type == 'array' else '').strip() + ': '
while True:
result = value or input(prompt) or default
if result == 'n/a':
result = None
if not result:
if required:
value = None
continue
else:
return
if field_type == 'array':
    result_list = result.split(',')
    if not (min_items <= len(result_list) <= max_items):
        if required:
            value = None
            continue
        else:
            return []
    for item in result_list:
        if not _check_type(item):
            if required:
                value = None
                break
            else:
                return []
    else:
        return [_convert_type(r) for r in result_list]
else:
if _check_type(result):
return _convert_type(result)
elif required:
value = None
continue
return
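# Illustrative non-interactive example: when a value is supplied up front it is
# type-checked and converted without prompting, e.g.
# schema_prompt('risk_score', value='73', required=True, type='integer') returns 73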
def get_kibana_rules_map(branch='master'):
"""Get list of available rules from the Kibana repo and return a list of URLs."""
r = requests.get('https://api.github.com/repos/elastic/kibana/branches?per_page=1000')
branch_names = [b['name'] for b in r.json()]
if branch not in branch_names:
raise ValueError('branch "{}" does not exist in kibana'.format(branch))
url = ('https://api.github.com/repos/elastic/kibana/contents/x-pack/{legacy}plugins/siem/server/lib/'
'detection_engine/rules/prepackaged_rules?ref={branch}')
gh_rules = requests.get(url.format(legacy='', branch=branch)).json()
# pre-7.8 the siem was under the legacy directory
if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found':
gh_rules = requests.get(url.format(legacy='legacy/', branch=branch)).json()
return {os.path.splitext(r['name'])[0]: r['download_url'] for r in gh_rules if r['name'].endswith('.json')}
def get_kibana_rules(*rule_paths, branch='master', verbose=True):
"""Retrieve prepackaged rules from kibana repo."""
if verbose:
click.echo('Downloading rules from {} branch in kibana repo...'.format(branch))
if rule_paths:
rule_paths = [os.path.splitext(os.path.basename(p))[0] for p in rule_paths]
return {n: requests.get(r).json() for n, r in get_kibana_rules_map(branch).items() if n in rule_paths}
else:
return {n: requests.get(r).json() for n, r in get_kibana_rules_map(branch).items()}
def parse_config():
"""Parse a default config file."""
global _CONFIG
if not _CONFIG:
config_file = os.path.join(ROOT_DIR, '.siem-rules-cfg.json')
if os.path.exists(config_file):
with open(config_file) as f:
_CONFIG = json.load(f)
click.secho('Loaded config file: {}'.format(config_file), fg='yellow')
return _CONFIG
def set_param_values(ctx, param, value):
"""Get value for defined key."""
key = param.name
config = parse_config()
env_key = 'SR_' + key
prompt = True if param.hide_input is not False else False
if value:
return value
elif os.environ.get(env_key):
return os.environ[env_key]
elif config.get(key):
return config[key]
elif prompt:
return click.prompt(key, default=param.default if not param.default else None, hide_input=param.hide_input,
show_default=True if param.default else False)
+252
@@ -0,0 +1,252 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Packaging and preparation for releases."""
import base64
import hashlib
import json
import os
import shutil
from collections import OrderedDict
import click
from . import rule_loader
from .rule import Rule # noqa: F401
from .utils import get_path, get_etc_path
RELEASE_DIR = get_path("releases")
PACKAGE_FILE = get_etc_path('packages.yml')
RULE_VERSIONS = get_etc_path('version.lock.json')
def filter_rule(rule, config_filter): # type: (Rule,dict) -> bool # rule.contents (not api), filter_dict -> match
"""Filter a rule based off metadata and a package configuration."""
flat_rule = rule.flattened_contents
for key, values in config_filter.items():
if key not in flat_rule:
return False
values = set([v.lower() if isinstance(v, str) else v for v in values])
rule_value = flat_rule[key]
if isinstance(rule_value, list):
rule_values = {v.lower() if isinstance(v, str) else v for v in rule_value}
else:
rule_values = {rule_value.lower() if isinstance(rule_value, str) else rule_value}
if len(rule_values & values) == 0:
return False
return True
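# Illustrative example: with a package config filter such as
# {'maturity': ['production'], 'os_type_list': ['windows', 'linux']}
# a rule passes only if every key exists in its flattened contents and at least one
# value overlaps (string comparisons are case-insensitive).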
def manage_versions(rules, current_versions=None, exclude_version_update=False, add_new=True, save_changes=False,
verbose=True):
# type: (list, dict, bool, bool, bool, bool) -> [list, list]
"""Update the contents of the version.lock file and optionally save changes."""
new_rules = {}
changed_rules = []
if current_versions is None:
with open(RULE_VERSIONS, 'r') as f:
current_versions = json.load(f)
for rule in rules:
# it is a new rule, so add it if specified, and add an initial version to the rule
if rule.id not in current_versions:
new_rules[rule.id] = {'rule_name': rule.name, 'version': 1, 'sha256': rule.get_hash()}
rule.contents['version'] = 1
else:
version_lock_info = current_versions.get(rule.id)
version = version_lock_info['version']
rule_hash = rule.get_hash()
# if it has been updated, then we need to bump the version info and optionally save the changes later
if rule_hash != version_lock_info['sha256']:
rule.contents['version'] = version + 1
if not exclude_version_update:
version_lock_info['version'] = rule.contents['version']
version_lock_info.update(sha256=rule_hash, rule_name=rule.name)
changed_rules.append(rule.id)
else:
rule.contents['version'] = version
# update the document with the new rules
if new_rules or changed_rules:
if verbose:
click.echo('Rule hash changes detected!')
if save_changes:
current_versions.update(new_rules if add_new else {})
current_versions = OrderedDict(sorted(current_versions.items(), key=lambda x: x[1]['rule_name']))
with open(RULE_VERSIONS, 'w') as f:
json.dump(current_versions, f, indent=2, sort_keys=True)
if verbose:
click.echo('Updated version.lock.json file with:')
else:
if verbose:
click.echo('run `build-release --update-version-lock` to update the version.lock.json file')
if verbose:
if changed_rules:
click.echo(' - {} changed rule version(s)'.format(len(changed_rules)))
if new_rules:
click.echo(' - {} new rule version addition(s)'.format(len(new_rules)))
return changed_rules, new_rules.keys()
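# For reference, each entry in version.lock.json managed above has the shape:
# "<rule_id>": {"rule_name": "<name>", "sha256": "<hash of rule contents>", "version": 1}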
class Package(object):
"""Packaging object for siem rules and releases."""
def __init__(self, rules, name, tune=False, release=False, current_versions=None, min_version=None,
max_version=None, update_version_lock=False):
"""Initialize a package."""
self.rules = [r.copy() for r in rules] # type: list[Rule]
self.name = name
self.release = release
self.changed_rules, self.new_rules = self._add_versions(current_versions, update_version_lock)
if min_version or max_version:
self.rules = [r for r in self.rules
if (min_version or 0) <= r.contents['version'] <= (max_version or r.contents['version'])]
if tune:
for rule in rules:
rule.tune()
def _add_versions(self, current_versions, update_versions_lock=False):
"""Add versions to rules at load time."""
return manage_versions(self.rules, current_versions=current_versions, save_changes=update_versions_lock)
def save_release_files(self, directory, changed_rules, new_rules):
"""Release a package."""
# TODO:
# xslx of mitre coverage
# release notes
with open(os.path.join(directory, '{}-summary.txt'.format(self.name)), 'w') as f:
f.write(self.generate_summary(changed_rules, new_rules))
with open(os.path.join(directory, '{}-consolidated.json'.format(self.name)), 'w') as f:
json.dump(json.loads(self.get_consolidated()), f, sort_keys=True, indent=2)
def get_consolidated(self, as_api=True):
"""Get a consolidated package of the rules in a single file."""
full_package = []
for rule in self.rules:
full_package.append(rule.contents if as_api else rule.rule_format())
return json.dumps(full_package, sort_keys=True)
def save(self, verbose=True):
"""Save a package and all artifacts."""
save_dir = os.path.join(RELEASE_DIR, self.name)
rules_dir = os.path.join(save_dir, 'rules')
extras_dir = os.path.join(save_dir, 'extras')
# remove anything that existed before
shutil.rmtree(save_dir, ignore_errors=True)
os.makedirs(rules_dir, exist_ok=True)
os.makedirs(extras_dir, exist_ok=True)
for rule in self.rules:
rule.save(new_path=os.path.join(rules_dir, os.path.basename(rule.path)))
if self.release:
self.save_release_files(extras_dir, self.changed_rules, self.new_rules)
# zip all rules only and place in extras
shutil.make_archive(os.path.join(extras_dir, self.name), 'zip', root_dir=os.path.dirname(rules_dir),
base_dir=os.path.basename(rules_dir))
# zip everything and place in release root
shutil.make_archive(os.path.join(save_dir, '{}-all'.format(self.name)), 'zip',
root_dir=os.path.dirname(extras_dir), base_dir=os.path.basename(extras_dir))
if verbose:
click.echo('Package saved to: {}'.format(save_dir))
def from_github(self):
"""Retrieve previously released and staged packages."""
def get_package_hash(self, as_api=True, verbose=True):
"""Get hash of package contents."""
contents = base64.b64encode(self.get_consolidated(as_api=as_api).encode('utf-8'))
sha256 = hashlib.sha256(contents).hexdigest()
if verbose:
click.echo('- sha256: {}'.format(sha256))
return sha256
@classmethod
def from_config(cls, config=None, update_version_lock=False): # type: (dict, bool) -> Package
"""Load a rules package given a config."""
all_rules = rule_loader.load_rules(verbose=False).values()
config = config or {}
rule_filter = config.pop('filter', {})
min_version = config.pop('min_version', None)
max_version = config.pop('max_version', None)
rules = filter(lambda rule: filter_rule(rule, rule_filter), all_rules)
update = config.pop('update', {})
package = cls(rules, min_version=min_version, max_version=max_version, update_version_lock=update_version_lock,
**config)
# Allow for some fields to be overwritten
if update.get('data', {}):
for rule in package.rules:
for sub_dict, values in update.items():
rule.contents[sub_dict].update(values)
return package
def generate_summary(self, changed_rules, new_rules):
"""Generate stats on package."""
ecs_versions = set()
indices = set()
changed = []
new = []
for rule in self.rules:
ecs_versions.update(rule.ecs_version)
indices.update(rule.contents.get('index', ''))
if rule.id in changed_rules:
changed.append('{} (v{})'.format(rule.name, rule.contents.get('version')))
elif rule.id in new_rules:
new.append('{} (v{})'.format(rule.name, rule.contents.get('version')))
total = 'Total Rules: {}'.format(len(self.rules))
sha256 = 'Package Hash: {}'.format(self.get_package_hash(verbose=False))
ecs_versions = 'ECS Versions: {}'.format(', '.join(ecs_versions))
indices = 'Included Indexes: {}'.format(', '.join(indices))
new_rules = 'New Rules: \n{}'.format('\n'.join(' - ' + s for s in sorted(new)) if new else 'N/A')
modified_rules = 'Modified Rules: \n{}'.format('\n'.join(' - ' + s for s in sorted(changed)) if changed else 'N/A')
return '\n'.join([total, sha256, ecs_versions, indices, new_rules, modified_rules])
def generate_mitre(self):
"""Create an excel file based on mitre coverage."""
# mapping with highlights of covered cells - links to pivot table with technique id selected
def reconcile_changes(self):
"""Parse and generate changes since previous release based on changed.toml file."""
# at packaging, generate flat changes file to standard, based on consolidated and deduped interpretation of
# changed.toml and clear out changes.toml
# - all based on api_format only
# see packages.yml - can update management.changed = True:
# until released in package, then added with filter and changed to False
def generate_change_notes(self):
"""Generate change release notes."""
def bump_versions(self, save_changes=False, current_versions=None):
"""Bump the versions of all production rules included in a release and optionally save changes."""
return manage_versions(self.rules, current_versions=current_versions, save_changes=save_changes)
+357
@@ -0,0 +1,357 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Rule object."""
import base64
import copy
import hashlib
import json
import os
import click
import kql
from . import ecs, beats
from .attack import TACTICS, build_threat_map_entry, technique_lookup
from .rule_formatter import nested_normalize, toml_write
from .schema import metadata_schema, schema_validate, get_schema
from .utils import get_path, clear_caches, cached
RULES_DIR = get_path("rules")
RULE_TYPE_OPTIONS = ['machine_learning', 'query', 'saved_id']
_META_SCHEMA_REQ_DEFAULTS = {}
class Rule(object):
"""Rule class containing all the information about a rule."""
def __init__(self, path, contents, tune=False):
"""Create a Rule from a toml management format."""
self.path = os.path.realpath(path)
self.contents = contents.get('rule', contents)
self.metadata = self.set_metadata(contents.get('metadata', contents))
self.formatted_rule = copy.deepcopy(self.contents).get('query', None)
self.validate()
self.unoptimized_query = self.contents.get('query')
self.tune_rule = tune
if tune:
    self.tune()
self._original_hash = self.get_hash()
def __str__(self):
return 'name={}, path={}, query={}'.format(self.name, self.path, self.query)
def __repr__(self):
return '{}(path={}, contents={}, tune={})'.format(type(self).__name__, repr(self.path), repr(self.contents),
repr(self.tune_rule))
def __eq__(self, other):
if type(self) == type(other):
return self.get_hash() == other.get_hash()
return False
def copy(self):
return Rule(path=self.path, contents={'rule': self.contents.copy(), 'metadata': self.metadata.copy()})
@property
def id(self):
return self.contents.get("rule_id")
@property
def name(self):
return self.contents.get("name")
@property
def query(self):
return self.contents.get('query')
@property
def parsed_kql(self):
if self.query and self.contents['language'] == 'kuery':
return kql.parse(self.query)
@property
def filters(self):
return self.contents.get('filters')
@property
def ecs_version(self):
return sorted(self.metadata.get('ecs_version', []))
@property
def flattened_contents(self):
return dict(self.contents, **self.metadata)
@property
def type(self):
return self.contents.get('type')
def to_eql(self):
if self.query and self.contents['language'] == 'kuery':
return kql.to_eql(self.query)
@staticmethod
@cached
def get_meta_schema_required_defaults():
"""Get the default values for required properties in the metadata schema."""
required = [v for v in metadata_schema['required']]
properties = {k: v for k, v in metadata_schema['properties'].items() if k in required}
return {k: v.get('default') or [v['items']['default']] for k, v in properties.items()}
def set_metadata(self, contents):
"""Parse metadata fields and set missing required fields to the default values."""
metadata = {k: v for k, v in contents.items() if k in metadata_schema['properties']}
defaults = self.get_meta_schema_required_defaults().copy()
defaults.update(metadata)
return defaults
def rule_format(self, formatted_query=True):
"""Get the contents in rule format."""
contents = self.contents.copy()
if formatted_query:
if self.formatted_rule:
contents['query'] = self.formatted_rule
return {'metadata': self.metadata, 'rule': contents}
def normalize(self, indent=2):
"""Normalize the (api only) contents and return a serialized dump of it."""
return json.dumps(nested_normalize(self.contents), sort_keys=True, indent=indent)
def tune(self):
"""Tune query by including applicable fields derived from metadata."""
# if not self.query:
# return
#
# self.unoptimized_query = self.contents.get('query')
#
# if not hasattr(self.parsed_query, 'terms'):
# # can prepend here if we want
# return
#
# # TODO: This is error prone and absolutely can/should be better done with a custom walker to:
# # - find these fields
# # - move them to the front/highest precedence
# # - dedup+update them with these values from metadata
# # I am going to leave it for now as a good mechanism for testing the theory and since it only impacts at
# # "package" time and will open an issue in the meantime
#
# # add os version
# # many os ecs fields - will optimize later
# # if not any(str(term.left) == '' for term in parsed_query.terms) and self.metadata.get('os_type_list'):
# # self.contents['query'] = ':({}) and '.format(' or '.join(self.metadata['_os_type_list'])) + self.query
#
# # add ecs version
# # handle these better with eql2kql
# compares = [str(term.left) == 'ecs.version' for term in self.parsed_query.terms
# if isinstance(term, Comparison)]
# in_sets = [str(term.expression) == 'ecs.version' for term in self.parsed_query.terms
# if isinstance(term, InSet)]
#
# if any(in_sets):
# pass
# elif any(compares):
# pass
# elif not (any(compares) or any(in_sets)):
# ecs_query = ' or '.join(self.metadata['ecs_version'])
# self.contents['query'] = 'ecs.version:({}) and '.format(ecs_query) + self.query
def untune(self):
"""Restore query to pre-tuned state."""
# self.contents['query'] = self.unoptimized_query
def get_path(self):
"""Wrapper around getting path."""
if not self.path:
raise ValueError('path not set for rule: \n\t{}'.format(self))
return self.path
def needs_save(self):
"""Determines if the rule was changed from original or was never saved."""
return self._original_hash != self.get_hash()
@classmethod # TODO
def from_eql_rule(cls, path, contents, validate=False):
"""Create a rule from loaded rule (toml) contents."""
# if validate:
# jsonschema.validate(contents, rule_schema)
return cls(path, contents)
def bump_version(self):
"""Bump the version of the rule."""
self.contents['version'] += 1
def validate(self, as_rule=False, versioned=False):
"""Validate against a rule schema, query schema, and linting."""
self.normalize()
if as_rule:
schema_validate(self.rule_format(), as_rule=True)
else:
schema_validate(self.contents, versioned=versioned)
if self.query and self.contents['language'] == 'kuery':
# validate against all specified schemas or the latest if none specified
ecs_versions = self.metadata.get('ecs_version')
indexes = self.contents.get("index", [])
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
beat_schema = beats.get_schema_for_query(self.parsed_kql, beat_types) if beat_types else None
if not ecs_versions:
kql.parse(self.query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))
else:
for version in ecs_versions:
try:
schema = ecs.get_kql_schema(version=version, indexes=indexes, beat_schema=beat_schema)
except KeyError:
raise KeyError(
'Unknown ecs schema version: {} in rule {}.\n'
'Do you need to update schemas?'.format(version, self.name))
try:
kql.parse(self.query, schema=schema)
except kql.KqlParseError as exc:
message = exc.error_msg
trailer = None
if "Unknown field" in message and beat_types:
trailer = "\nTry adding event.module and event.dataset to specify beats module"
raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
len(exc.caret.lstrip()), trailer=trailer)
def save(self, new_path=None, as_rule=False, verbose=False):
"""Save as pretty toml rule file as toml."""
path, _ = os.path.splitext(new_path or self.get_path())
path += '.toml' if as_rule else '.json'
if as_rule:
toml_write(self.rule_format(), path)
else:
with open(path, 'w', newline='\n') as f:
json.dump(self.contents, f, sort_keys=True, indent=2)
f.write('\n')
if verbose:
print('Rule {} saved to {}'.format(self.name, path))
def get_hash(self):
"""Get a standardized hash of a rule to consistently check for changes."""
contents = base64.b64encode(json.dumps(self.contents, sort_keys=True).encode('utf-8'))
return hashlib.sha256(contents).hexdigest()
@classmethod
def build(cls, path=None, rule_type=None, required_only=True, save=True, **kwargs):
"""Build a rule from data and prompts."""
from .misc import schema_prompt
# from .rule_loader import rta_mappings
kwargs = copy.deepcopy(kwargs)
while rule_type not in RULE_TYPE_OPTIONS:
rule_type = click.prompt('Rule type ({})'.format(', '.join(RULE_TYPE_OPTIONS)))
schema = get_schema(rule_type)
props = schema['properties']
opt_reqs = schema.get('required', [])
contents = {}
skipped = []
for name, options in props.items():
if name == 'type':
contents[name] = rule_type
continue
# these are set at package release time
if name == 'version':
continue
if required_only and name not in opt_reqs:
continue
# build this from technique ID
if name == 'threat':
threat_map = []
while click.confirm('Add MITRE tactic?'):
tactic = schema_prompt('mitre tactic name', type='string', enum=TACTICS, required=True)
technique_ids = schema_prompt(f'technique IDs for {tactic}', type='array', required=True,
enum=list(technique_lookup))
threat_map.append(build_threat_map_entry(tactic, *technique_ids))
if len(threat_map) > 0:
contents[name] = threat_map
continue
if kwargs.get(name):
contents[name] = schema_prompt(kwargs.pop(name))
continue
result = schema_prompt(name, required=name in opt_reqs, **options)
if result:
if name not in opt_reqs and result == options.get('default', ''):
skipped.append(name)
continue
contents[name] = result
metadata = {}
ecs_version = schema_prompt('ecs_version', required=False, value=None,
**metadata_schema['properties']['ecs_version'])
if ecs_version:
metadata['ecs_version'] = ecs_version
# validate before creating
schema_validate(contents)
suggested_path = os.path.join(RULES_DIR, contents['name']) # TODO: UPDATE BASED ON RULE STRUCTURE
path = os.path.realpath(path or input('File path for rule [{}]: '.format(suggested_path)) or suggested_path)
rule = None
try:
rule = cls(path, {'rule': contents, 'metadata': metadata})
except kql.KqlParseError as e:
if e.error_msg == 'Unknown field':
warning = ('If using a non-ECS field, you must update "ecs{}.non-ecs-schema.json" under `beats` or '
'`legacy-endgame` (Non-ECS fields should be used minimally).'.format(os.path.sep))
click.secho(e.args[0], fg='red', err=True)
click.secho(warning, fg='yellow', err=True)
click.pause()
# if failing due to a query, loop until resolved or terminated
while True:
try:
contents['query'] = click.edit(contents['query'], extension='.eql')
rule = cls(path, {'rule': contents, 'metadata': metadata})
except kql.KqlParseError as e:
click.secho(e.args[0], fg='red', err=True)
click.pause()
if e.error_msg.startswith("Unknown field"):
# get the latest schema for schema errors
clear_caches()
ecs.get_kql_schema(indexes=contents.get("index", []))
continue
break
if save:
rule.save(verbose=True, as_rule=True)
if skipped:
print('Did not set the following values because they are optional and matched the default value')
print(' - {}'.format('\n - '.join(skipped)))
# rta_mappings.add_rule_to_mapping_file(rule)
click.echo('Placeholder added to rule-mapping.yml')
return rule
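One practical consequence of the hashing scheme above: get_hash() serializes with sort_keys=True, so two rules that differ only in key order produce the same digest and needs_save() stays false. A minimal standalone sketch of the same idea, assuming nothing beyond the standard library:

import base64
import hashlib
import json

def standardized_hash(contents):
    # mirrors Rule.get_hash: key-sorted JSON -> base64 -> sha256 hex digest
    encoded = base64.b64encode(json.dumps(contents, sort_keys=True).encode('utf-8'))
    return hashlib.sha256(encoded).hexdigest()

a = standardized_hash({'name': 'Example Rule', 'risk_score': 21})
b = standardized_hash({'risk_score': 21, 'name': 'Example Rule'})
assert a == b  # key order alone never triggers a spurious save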
+193
View File
@@ -0,0 +1,193 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Helper functions for managing rules in the repository."""
import copy
import io
import textwrap
from collections import OrderedDict
import toml
from .schema import NONFORMATTED_FIELDS
SQ = "'"
DQ = '"'
TRIPLE_SQ = SQ * 3
TRIPLE_DQ = DQ * 3
def cleanup_whitespace(val):
if isinstance(val, str):
return " ".join(line.strip() for line in val.strip().splitlines())
return val
def nested_normalize(d, skip_cleanup=False):
if isinstance(d, str):
return d if skip_cleanup else cleanup_whitespace(d)
elif isinstance(d, list):
return [nested_normalize(val) for val in d]
elif isinstance(d, dict):
for k, v in d.items():
if k == 'query':
# TODO: the linter still needs some work, but once up to par, uncomment to implement - kql.lint(v)
d.update({k: nested_normalize(v)})
elif k in NONFORMATTED_FIELDS:
# let these maintain newlines and whitespace for markdown support
d.update({k: nested_normalize(v, skip_cleanup=True)})
else:
d.update({k: nested_normalize(v)})
return d
else:
return d
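# Editor's sketch (values illustrative): interior whitespace collapses
# everywhere except the NONFORMATTED_FIELDS, e.g. 'note':
#     >>> nested_normalize({'description': 'two\n  lines', 'note': 'keep\nlines'})
#     {'description': 'two lines', 'note': 'keep\nlines'}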
def wrap_text(v, block_indent=0, join=False):
"""Block and indent a blob of text."""
v = ' '.join(v.split())
lines = textwrap.wrap(v, initial_indent=' ' * block_indent, subsequent_indent=' ' * block_indent, width=120,
break_long_words=False, break_on_hyphens=False)
lines = [line + '\n' for line in lines]
return lines if not join else ''.join(lines)
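# Editor's sketch: runs of whitespace collapse and every wrapped line keeps a
# trailing newline, which dump_str relies on below:
#     >>> wrap_text('one  short   line', join=True)
#     'one short line\n'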
class NonformattedField(str):
"""Non-formatting class."""
class RuleTomlEncoder(toml.TomlEncoder):
"""Generate a pretty form of toml."""
def __init__(self, _dict=dict, preserve=False):
"""Create the encoder but override some default functions."""
super(RuleTomlEncoder, self).__init__(_dict, preserve)
self._old_dump_str = toml.TomlEncoder().dump_funcs[str]
self._old_dump_list = toml.TomlEncoder().dump_funcs[list]
self.dump_funcs[str] = self.dump_str
self.dump_funcs[type(u"")] = self.dump_str
self.dump_funcs[list] = self.dump_list
self.dump_funcs[NonformattedField] = self.dump_str
def dump_str(self, v):
"""Change the TOML representation to multi-line or single quote when logical."""
initial_newline = ['\n']
if isinstance(v, NonformattedField):
# first line break is not forced like other multiline string dumps
lines = v.splitlines(True)
initial_newline = []
else:
lines = wrap_text(v)
multiline = len(lines) > 1
raw = (multiline or (DQ in v and SQ not in v)) and TRIPLE_DQ not in v
if multiline:
if raw:
return "".join([TRIPLE_DQ] + initial_newline + lines + [TRIPLE_DQ])
else:
return "\n".join([TRIPLE_SQ] + [self._old_dump_str(line)[1:-1] for line in lines] + [TRIPLE_SQ])
elif raw:
return u"'{:s}'".format(lines[0])
return self._old_dump_str(v)
def _dump_flat_list(self, v):
"""A slightly tweaked version of original dump_list, removing trailing commas."""
if not v:
return "[]"
retval = "[" + str(self.dump_value(v[0])) + ","
for u in v[1:]:
retval += " " + str(self.dump_value(u)) + ","
retval = retval.rstrip(',') + "]"
return retval
def dump_list(self, v):
"""Dump a list more cleanly."""
if all([isinstance(d, str) for d in v]) and sum(len(d) + 3 for d in v) > 100:
dump = []
for item in v:
if len(item) > (120 - 4 - 3 - 3) and ' ' in item:
dump.append(' """\n{} """'.format(wrap_text(item, block_indent=4, join=True)))
else:
dump.append(' ' * 4 + self.dump_value(item))
return '[\n{},\n]'.format(',\n'.join(dump))
return self._dump_flat_list(v)
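# Editor's sketch of the encoder in isolation (assumes only this module plus
# the toml package):
#     >>> import toml
#     >>> toml.dumps({'tags': ['Elastic', 'Windows']}, encoder=RuleTomlEncoder())
#     'tags = ["Elastic", "Windows"]\n'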
def toml_write(rule_contents, outfile=None):
"""Write rule in TOML."""
def write(text, nl=True):
if outfile:
outfile.write(text)
if nl:
outfile.write(u"\n")
else:
print(text, end='' if not nl else '\n')
encoder = RuleTomlEncoder()
contents = copy.deepcopy(rule_contents)
needs_close = False
def _do_write(_data, _contents):
query = None
if _data == 'rule':
# - We want to avoid the encoder for the query and instead use kql-lint.
# - Linting is done in rule.normalize() which is also called in rule.validate().
# - Until lint has tabbing, this is going to result in all queries being flattened with no wrapping,
# but will at least purge extraneous white space
query = contents['rule'].pop('query', '').strip()
tags = contents['rule'].get("tags", [])
if tags and isinstance(tags, list):
contents['rule']["tags"] = list(sorted(set(tags)))
top = OrderedDict()
bottom = OrderedDict()
for k in sorted(list(_contents)):
v = _contents.pop(k)
if isinstance(v, dict):
bottom[k] = OrderedDict(sorted(v.items()))
elif isinstance(v, list):
if any([isinstance(value, (dict, list)) for value in v]):
bottom[k] = v
else:
top[k] = v
elif k in NONFORMATTED_FIELDS:
top[k] = NonformattedField(v)
else:
top[k] = v
if query:
top.update({'query': "XXxXX"})
top.update(bottom)
top = toml.dumps(OrderedDict({data: top}), encoder=encoder)
# we want to preserve the query format, but want to modify it in the context of encoded dump
if query:
formatted_query = "\nquery = '''\n{}\n'''{}".format(query, '\n\n' if bottom else '')
top = top.replace('query = "XXxXX"', formatted_query)
write(top)
try:
if outfile and not isinstance(outfile, io.IOBase):
needs_close = True
outfile = open(outfile, 'w')
for data in ('metadata', 'rule'):
_contents = contents.get(data, {})
_do_write(data, _contents)
finally:
if needs_close:
outfile.close()
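To make the write path concrete: with no outfile, toml_write prints to stdout, emitting [metadata] before [rule] and splicing the query back in as a triple-quoted literal after encoding. A hedged usage sketch (the import path and contents are illustrative):

from detection_rules.rule_formatter import toml_write

contents = {
    'metadata': {'creation_date': '2020/06/29', 'maturity': 'development'},
    'rule': {'name': 'Example', 'query': 'process.name:cmd.exe', 'risk_score': 21},
}
toml_write(contents)  # [metadata] section first, then [rule] with query = ''' ... '''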
+192
View File
@@ -0,0 +1,192 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Load rule metadata transform between rule and api formats."""
import functools
import glob
import io
import os
import re
from collections import OrderedDict
import click
import pytoml
from .mappings import RtaMappings
from .rule import RULES_DIR, Rule
from .schema import get_schema
from .utils import get_path, cached
RTA_DIR = get_path("rta")
FILE_PATTERN = r'^([a-z0-9_])+\.(json|toml)$'
def mock_loader(f):
"""Mock rule loader."""
@functools.wraps(f)
def wrapped(*args, **kwargs):
try:
return f(*args, **kwargs)
finally:
load_rules.clear()
return wrapped
def reset():
"""Clear all rule caches."""
load_rule_files.clear()
load_rules.clear()
get_rule.clear()
filter_rules.clear()
@cached
def load_rule_files(verbose=True):
"""Load the rule YAML files, but without parsing the EQL query portion."""
file_lookup = {} # type: dict[str, dict]
if verbose:
print("Loading rules from {}".format(RULES_DIR))
for rule_file in sorted(glob.glob(os.path.join(RULES_DIR, '**', '*.toml'), recursive=True)):
try:
# use pytoml instead of toml because of annoying bugs
# https://github.com/uiri/toml/issues/152
# might also be worth looking at https://github.com/sdispater/tomlkit
with io.open(rule_file, "r", encoding="utf-8") as f:
file_lookup[rule_file] = pytoml.load(f)
except Exception:
print(u"Error loading {}".format(rule_file))
raise
if verbose:
print("Loaded {} rules".format(len(file_lookup)))
return file_lookup
@cached
def load_rules(file_lookup=None, verbose=True, error=True):
"""Load all the rules from toml files."""
file_lookup = file_lookup or load_rule_files(verbose=verbose)
failed = False
rules = [] # type: list[Rule]
errors = []
queries = []
rule_ids = set()
rule_names = set()
for rule_file, rule_contents in file_lookup.items():
try:
rule = Rule(rule_file, rule_contents)
if rule.id in rule_ids:
raise KeyError("Rule has duplicate ID to {}".format(next(r for r in rules if r.id == rule.id).path))
if rule.name in rule_names:
raise KeyError("Rule has duplicate name to {}".format(
next(r for r in rules if r.name == rule.name).path))
if rule.parsed_kql:
if rule.parsed_kql in queries:
raise KeyError("Rule has duplicate query with {}".format(
next(r for r in rules if r.parsed_kql == rule.parsed_kql).path))
queries.append(rule.parsed_kql)
if not re.match(FILE_PATTERN, os.path.basename(rule.path)):
raise ValueError(f"Rule {rule.path} does not meet rule name standard of {FILE_PATTERN}")
rules.append(rule)
rule_ids.add(rule.id)
rule_names.add(rule.name)
except Exception as e:
failed = True
err_msg = "Invalid rule file in {}\n{}".format(rule_file, click.style(e.args[0], fg='red'))
errors.append(err_msg)
if error:
print(err_msg)
raise e
if failed:
if verbose:
for e in errors:
print(e)
return OrderedDict([(rule.id, rule) for rule in sorted(rules, key=lambda r: r.name)])
@cached
def get_rule(rule_id=None, rule_name=None, file_name=None, verbose=True):
"""Get a rule based on its id."""
rules_lookup = load_rules(verbose=verbose)
if rule_id is not None:
return rules_lookup.get(rule_id)
for rule in rules_lookup.values(): # type: Rule
if rule.name == rule_name:
return rule
elif rule.path == file_name:
return rule
def get_rule_name(rule_id, verbose=True):
"""Get the name of a rule given the rule id."""
rule = get_rule(rule_id, verbose=verbose)
if rule:
return rule.name
def get_file_name(rule_id, verbose=True):
"""Get the file path that corresponds to a rule."""
rule = get_rule(rule_id, verbose=verbose)
if rule:
return rule.path
def get_rule_contents(rule_id, verbose=True):
"""Get the full contents for a rule_id."""
rule = get_rule(rule_id, verbose=verbose)
if rule:
return rule.contents
@cached
def filter_rules(rules, metadata_field, value):
"""Filter rules based on the metadata."""
return [rule for rule in rules if rule.metadata.get(metadata_field, {}) == value]
def get_production_rules():
"""Get rules with a maturity of production."""
return filter_rules(load_rules().values(), 'maturity', 'production')
def find_unneeded_defaults(rule):
"""Remove values that are not required in the schema which are set with default values."""
schema = get_schema(rule.contents['type'])
props = schema['properties']
unrequired_defaults = [p for p in props if p not in schema['required'] and props[p].get('default')]
default_matches = {p: rule.contents[p] for p in unrequired_defaults
if rule.contents.get(p) and rule.contents[p] == props[p]['default']}
return default_matches
rta_mappings = RtaMappings()
__all__ = (
"load_rules",
"get_file_name",
"get_production_rules",
"get_rule",
"filter_rules",
"get_rule_name",
"get_rule_contents",
"reset",
"rta_mappings"
)
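The exported loader functions compose into a small read path. A hedged sketch, assuming rules already exist under RULES_DIR and the package imports as detection_rules:

from detection_rules import rule_loader

rules = rule_loader.load_rules(verbose=False)          # OrderedDict keyed by rule_id
one = rule_loader.get_rule(rule_name='Example Rule')   # or look up by id / file name
production = rule_loader.get_production_rules()        # maturity == 'production' only
rule_loader.reset()                                    # drop every memoized cache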
+238
View File
@@ -0,0 +1,238 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Definitions for rule metadata and schemas."""
import time
import jsl
import jsonschema
from . import ecs
from .attack import TACTICS, TACTICS_MAP, TECHNIQUES, technique_lookup
UUID_PATTERN = r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}'
DATE_PATTERN = r'\d{4}/\d{2}/\d{2}'
VERSION_PATTERN = r'\d+\.\d+\.\d+'
RULE_LEVELS = ['recommended', 'aggressive']
MATURITY_LEVELS = ['development', 'testing', 'staged', 'production', 'deprecated']
OS_OPTIONS = ['windows', 'linux', 'macos', 'solaris'] # need to verify with ecs
INTERVAL_PATTERN = r'\d+[mshd]'
MITRE_URL_PATTERN = r'https://attack.mitre.org/{type}/T[A-Z0-9]+/'
NONFORMATTED_FIELDS = ('note',)
# kibana/.../siem/server/lib/detection_engine/routes/schemas/add_prepackaged_rules_schema.ts
# /detection_engine/routes/schemas/schemas.ts
# rule_id is required here
# output_index is not allowed (and instead the space index must be used)
# immutable defaults to true instead of to false and if it is there can only be true
# enabled defaults to false instead of true
# version is a required field that must exist
MACHINE_LEARNING = 'machine_learning'
SAVED_QUERY = 'saved_query'
QUERY = 'query'
class FilterMetadata(jsl.Document):
"""Base class for siem rule meta filters."""
negate = jsl.BooleanField()
type = jsl.StringField()
key = jsl.StringField()
value = jsl.StringField()
disabled = jsl.BooleanField()
indexRefName = jsl.StringField()
alias = jsl.StringField() # null acceptable
params = jsl.DictField(properties={'query': jsl.StringField()})
class FilterQuery(jsl.Document):
"""Base class for siem rule query filters."""
match = jsl.DictField({
'event.action': jsl.DictField(properties={
'query': jsl.StringField(),
'type': jsl.StringField()
})
})
class FilterState(jsl.Document):
"""Base class for siem rule $state filters."""
store = jsl.StringField()
class FilterExists(jsl.Document):
"""Base class for siem rule $state filters."""
field = jsl.StringField()
class Filters(jsl.Document):
"""Schema for filters"""
exists = jsl.DocumentField(FilterExists)
meta = jsl.DocumentField(FilterMetadata)
state = jsl.DocumentField(FilterState, name='$state')
query = jsl.DocumentField(FilterQuery)
class ThreatTactic(jsl.Document):
"""Threat tactics."""
id = jsl.StringField(enum=list(TACTICS_MAP.values()))
name = jsl.StringField(enum=TACTICS)
reference = jsl.StringField(MITRE_URL_PATTERN.format(type='tactics'))
class ThreatTechnique(jsl.Document):
"""Threat tactics."""
id = jsl.StringField(enum=list(technique_lookup))
name = jsl.StringField(enum=TECHNIQUES)
reference = jsl.StringField(MITRE_URL_PATTERN.format(type='techniques'))
class Threat(jsl.Document):
"""Threat framework mapping such as MITRE ATT&CK."""
framework = jsl.StringField(default='MITRE ATT&CK', required=True)
tactic = jsl.DocumentField(ThreatTactic, required=True)
technique = jsl.ArrayField(jsl.DocumentField(ThreatTechnique), required=True)
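# Editor's sketch of a threat entry this schema accepts (values illustrative,
# taken from public ATT&CK data rather than this commit):
#     {
#         "framework": "MITRE ATT&CK",
#         "tactic": {"id": "TA0002", "name": "Execution",
#                    "reference": "https://attack.mitre.org/tactics/TA0002/"},
#         "technique": [{"id": "T1059", "name": "Command and Scripting Interpreter",
#                        "reference": "https://attack.mitre.org/techniques/T1059/"}]
#     }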
class SiemRuleApiSchema(jsl.Document):
"""Schema for siem rule in API format."""
actions = jsl.ArrayField(required=False)
author = jsl.ArrayField(jsl.StringField(default="Elastic"), required=True, min_items=1)
description = jsl.StringField(required=True)
# api defaults to false if blank
enabled = jsl.BooleanField(default=False, required=False)
exceptions_list = jsl.ArrayField(required=False)
# _ required since `from` is a reserved word in python
from_ = jsl.StringField(required=False, default='now-6m', name='from')
false_positives = jsl.ArrayField(jsl.StringField(), required=False)
filters = jsl.ArrayField(jsl.DocumentField(Filters))
interval = jsl.StringField(pattern=INTERVAL_PATTERN, default='5m', required=False)
license = jsl.StringField(required=True, default="Elastic License")
max_signals = jsl.IntField(minimum=1, required=False, default=100) # cap a max?
meta = jsl.DictField(required=False)
name = jsl.StringField(required=True)
note = jsl.StringField(required=False)
# output_index = jsl.StringField(required=False) # this is NOT allowed!
references = jsl.ArrayField(jsl.StringField(), required=False)
risk_score = jsl.IntField(minimum=0, maximum=100, required=True, default=21)
rule_id = jsl.StringField(pattern=UUID_PATTERN, required=True)
severity = jsl.StringField(enum=['low', 'medium', 'high', 'critical'], default='low', required=True)
# saved_id - type must be 'saved_query' to allow this or else it is forbidden
tags = jsl.ArrayField(jsl.StringField(), required=False)
throttle = jsl.StringField(required=False)
timeline_id = jsl.StringField(required=False)
timeline_title = jsl.StringField(required=False)
to = jsl.StringField(required=False, default='now')
# require this to be always validated with a role
# type = jsl.StringField(enum=[MACHINE_LEARNING, QUERY, SAVED_QUERY], required=True)
threat = jsl.ArrayField(jsl.DocumentField(Threat), required=False, min_items=1)
with jsl.Scope(MACHINE_LEARNING) as ml_scope:
ml_scope.anomaly_threshold = jsl.IntField(required=True, minimum=0)
ml_scope.machine_learning_job_id = jsl.StringField(required=True)
ml_scope.type = jsl.StringField(enum=[MACHINE_LEARNING], required=True, default=MACHINE_LEARNING)
with jsl.Scope(QUERY) as query_scope:
query_scope.index = jsl.ArrayField(jsl.StringField(), required=False)
# this is not required per the API but we will enforce it here
query_scope.language = jsl.StringField(enum=['kuery', 'lucene'], required=True, default='kuery')
query_scope.query = jsl.StringField(required=True)
query_scope.type = jsl.StringField(enum=[QUERY], required=True, default=QUERY)
with jsl.Scope(SAVED_QUERY) as saved_id_scope:
saved_id_scope.index = jsl.ArrayField(jsl.StringField(), required=False)
saved_id_scope.saved_id = jsl.StringField(required=True)
saved_id_scope.type = jsl.StringField(enum=[SAVED_QUERY], required=True, default=SAVED_QUERY)
class VersionedApiSchema(SiemRuleApiSchema):
"""Schema for siem rule in API format with version."""
version = jsl.IntField(minimum=1, default=1, required=True)
class SiemRuleTomlMetadata(jsl.Document):
"""Schema for siem rule toml metadata."""
creation_date = jsl.StringField(required=True, pattern=DATE_PATTERN, default=time.strftime('%Y/%m/%d'))
# added to query with rule.optimize()
# rule validated against each ecs schema contained
ecs_version = jsl.ArrayField(
jsl.StringField(pattern=VERSION_PATTERN, required=True, default=ecs.get_max_version()), required=True)
maturity = jsl.StringField(enum=MATURITY_LEVELS, default='development', required=True)
# if present, add to query
os_type_list = jsl.ArrayField(jsl.StringField(enum=OS_OPTIONS), required=False)
related_endpoint_rules = jsl.ArrayField(jsl.ArrayField(jsl.StringField(), min_items=2, max_items=2),
required=False)
updated_date = jsl.StringField(required=True, pattern=DATE_PATTERN, default=time.strftime('%Y/%m/%d'))
class SiemRuleTomlSchema(jsl.Document):
"""Schema for siem rule in management toml format."""
metadata = jsl.DocumentField(SiemRuleTomlMetadata)
rule = jsl.DocumentField(SiemRuleApiSchema)
class Package(jsl.Document):
"""Schema for siem rule staging."""
class MappingCount(jsl.Document):
"""Mapping count schema."""
count = jsl.IntField(minimum=0, required=True)
rta_name = jsl.StringField(pattern=r'[a-zA-Z-_]+', required=True)
rule_name = jsl.StringField(required=True)
sources = jsl.ArrayField(jsl.StringField(), min_items=1)
cached_schemas = {}
def get_schema(role, as_rule=False, versioned=False):
"""Get applicable schema by role type and rule format."""
if (role, as_rule, versioned) not in cached_schemas:
if versioned:
cls = VersionedApiSchema
else:
cls = SiemRuleTomlSchema if as_rule else SiemRuleApiSchema
cached_schemas[(role, as_rule, versioned)] = cls.get_schema(ordered=True, role=role)
return cached_schemas[(role, as_rule, versioned)]
def schema_validate(contents, as_rule=False, versioned=False):
"""Validate against all schemas until first hit."""
assert isinstance(contents, dict)
role = contents.get('rule', {}).get('type') if as_rule else contents.get('type')
if not role:
raise ValueError('Missing rule type!')
return jsonschema.validate(contents, get_schema(role, as_rule, versioned))
metadata_schema = SiemRuleTomlMetadata.get_schema(ordered=True)
package_schema = Package.get_schema(ordered=True)
mapping_schema = MappingCount.get_schema(ordered=True)
def validate_rta_mapping(mapping):
"""Validate the RTA mapping."""
jsonschema.validate(mapping, mapping_schema)
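Validation failures surface as plain jsonschema errors rather than custom exceptions. A minimal sketch of the failure path, run in this module's context with an intentionally incomplete rule (field values illustrative):

import jsonschema

partial = {'type': 'query', 'name': 'Example', 'query': 'process.name:cmd.exe'}
try:
    schema_validate(partial)  # missing rule_id, risk_score, and other required fields
except jsonschema.ValidationError as exc:
    print(exc.message)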
+25
View File
@@ -0,0 +1,25 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Helper functionality for comparing semantic versions."""
import re
class Version(tuple):
def __new__(cls, version):
if not isinstance(version, (int, list, tuple)):
version = tuple(int(a) if a.isdigit() else a for a in re.split(r'[.-]', version))
return tuple.__new__(cls, (version,) if isinstance(version, int) else version)  # a bare int is not iterable
def bump(self):
"""Increment the version."""
versions = list(self)
versions[-1] += 1
return Version(versions)
def __str__(self):
"""Convert back to a string."""
return ".".join(str(dig) for dig in self)
+186
View File
@@ -0,0 +1,186 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Util functions."""
import contextlib
import functools
import gzip
import io
import json
import os
import time
import zipfile
from datetime import datetime
import kql
import eql.utils
from eql.utils import stream_json_lines
CURR_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(CURR_DIR)
ETC_DIR = os.path.join(ROOT_DIR, "etc")
def get_json_iter(f):
"""Get an iterator over a JSON file."""
first = f.read(2)
f.seek(0)
if first[0] == '[' or first == "{\n":
return json.load(f)
else:
data = list(stream_json_lines(f))
return data
def get_path(*paths):
"""Get a file by relative path."""
return os.path.join(ROOT_DIR, *paths)
def get_etc_path(*paths):
"""Load a file from the etc/ folder."""
return os.path.join(ETC_DIR, *paths)
def get_etc_file(name, mode="r"):
"""Load a file from the etc/ folder."""
with open(get_etc_path(name), mode) as f:
return f.read()
def load_etc_dump(*path):
"""Load a json/yml/toml file from the etc/ folder."""
return eql.utils.load_dump(get_etc_path(*path))
def save_etc_dump(contents, *path):
"""Load a json/yml/toml file from the etc/ folder."""
return eql.utils.save_dump(contents, get_etc_path(*path))
def get_ecs_fields(endgame_field):
ecs_mapping = load_etc_dump('ecs_mappings.json')
return ecs_mapping.get(endgame_field)
def save_gzip(contents):
gz_file = io.BytesIO()
with gzip.GzipFile(mode="w", fileobj=gz_file) as f:
if not isinstance(contents, bytes):
contents = contents.encode("utf8")
f.write(contents)
return gz_file.getvalue()
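# Editor's sketch: the returned bytes round-trip with the standard library.
#     >>> gzip.decompress(save_gzip('hello')) == b'hello'
#     True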
@contextlib.contextmanager
def unzip(contents): # type: (bytes) -> zipfile.ZipFile
"""Get zipped contents."""
zipped = io.BytesIO(contents)
archive = zipfile.ZipFile(zipped, mode="r")
try:
yield archive
finally:
archive.close()
def unzip_and_save(contents, path, member=None, verbose=True):
"""Save unzipped from raw zipped contents."""
with unzip(contents) as archive:
if member:
archive.extract(member, path)
else:
archive.extractall(path)
if verbose:
name_list = [member] if member else archive.namelist()
print('Saved files to {}: \n\t- {}'.format(path, '\n\t- '.join(name_list)))
def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True):
"""Sort events from elasticsearch by timestamp."""
def _event_sort(event):
t = event[timestamp]
return (time.mktime(time.strptime(t, date_format)) + int(t.split('.')[-1][:-1]) / 1000) * 1000
return sorted(events, key=_event_sort, reverse=not asc)
def combine_sources(*sources): # type: (list[list]) -> list
"""Combine lists of events from multiple sources."""
combined = []
for source in sources:
combined.extend(source.copy())
return event_sort(combined)
def evaluate(rule, events):
"""Evaluate a query against events."""
evaluator = kql.get_evaluator(kql.parse(rule.query))
filtered = list(filter(evaluator, events))
return filtered
def unix_time_to_formatted(timestamp): # type: (int|str) -> str
"""Converts unix time in seconds or milliseconds to the default format."""
if isinstance(timestamp, (int, float)):
if timestamp > 2 ** 32:
timestamp = round(timestamp / 1000, 3)
return datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'
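# Editor's sketch: second- and millisecond-precision inputs normalize alike,
# since anything above 2 ** 32 is treated as milliseconds.
#     >>> unix_time_to_formatted(10000000)          # seconds
#     '1970-04-26T17:46:40.000Z'
#     >>> unix_time_to_formatted(10000000000)       # milliseconds
#     '1970-04-26T17:46:40.000Z'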
def normalize_timing_and_sort(events, timestamp='@timestamp', asc=True):
"""Normalize timestamp formats and sort events."""
for event in events:
_timestamp = event[timestamp]
if not isinstance(_timestamp, str):
event[timestamp] = unix_time_to_formatted(_timestamp)
return event_sort(events, timestamp=timestamp, asc=asc)
def freeze(obj):
"""Helper function to make mutable objects immutable and hashable."""
if isinstance(obj, (list, tuple)):
return tuple(freeze(o) for o in obj)
elif isinstance(obj, dict):
return freeze(list(sorted(obj.items())))
else:
return obj
_cache = {}
def cached(f):
"""Helper function to memoize functions."""
func_key = id(f)
@functools.wraps(f)
def wrapped(*args, **kwargs):
_cache.setdefault(func_key, {})
cache_key = freeze(args), freeze(kwargs)
if cache_key not in _cache[func_key]:
_cache[func_key][cache_key] = f(*args, **kwargs)
return _cache[func_key][cache_key]
def clear():
_cache.pop(func_key, None)
wrapped.clear = clear
return wrapped
def clear_caches():
_cache.clear()
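The memoization pair above backs the @cached decorators used throughout the loaders. A minimal sketch of the decorator's behavior, assuming only this module:

calls = []

@cached
def fib(n):
    calls.append(n)
    return n if n < 2 else fib(n - 1) + fib(n - 2)

assert fib(10) == 55
assert len(calls) == 11   # each n in 0..10 computed exactly once
fib.clear()               # per-function eviction, the same hook rule_loader.reset() uses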