Files
sigma-rules/detection_rules/utils.py
T

456 lines
13 KiB
Python
Raw Normal View History

2020-06-29 23:17:38 -06:00
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2021-03-03 22:12:11 -09:00
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
2020-06-29 23:17:38 -06:00
"""Util functions."""
2021-03-24 10:24:32 -06:00
import base64
2020-06-29 23:17:38 -06:00
import contextlib
import functools
import glob
2020-06-29 23:17:38 -06:00
import gzip
2021-03-24 10:24:32 -06:00
import hashlib
2020-06-29 23:17:38 -06:00
import io
import json
import os
import re
import shutil
import subprocess
2020-06-29 23:17:38 -06:00
import zipfile
2021-03-24 10:24:32 -06:00
from dataclasses import is_dataclass, astuple
from datetime import datetime, date
from pathlib import Path
from typing import Dict, Union, Optional, Callable
2023-03-28 07:17:50 -06:00
from string import Template
2020-06-29 23:17:38 -06:00
import click
2021-09-10 10:06:04 -08:00
import pytoml
2020-06-29 23:17:38 -06:00
import eql.utils
from eql.utils import load_dump, stream_json_lines
2020-06-29 23:17:38 -06:00
2021-03-24 10:24:32 -06:00
import kql
# Directory that contains this module.
CURR_DIR = Path(__file__).resolve().parent
# Parent of this module's directory; used as the base for get_path().
ROOT_DIR = CURR_DIR.parent
# <ROOT_DIR>/detection_rules/etc -- base directory for the *_etc_* helpers below.
ETC_DIR = ROOT_DIR / "detection_rules" / "etc"
# <ROOT_DIR>/rules/integrations
INTEGRATION_RULE_DIR = ROOT_DIR / "rules" / "integrations"
2020-06-29 23:17:38 -06:00
2021-03-24 10:24:32 -06:00
class NonelessDict(dict):
    """A dict whose item assignment silently ignores ``None`` values.

    Only ``__setitem__`` is overridden; values passed to the constructor or
    ``update()`` are not filtered.
    """

    def __setitem__(self, key, value):
        """Store *value* under *key* unless *value* is ``None``."""
        if value is None:
            return
        super().__setitem__(key, value)
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``date`` and ``datetime`` values via ``isoformat()``."""

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*.

        Raises:
            TypeError: if *obj* is neither a date/datetime nor otherwise serializable.
        """
        if isinstance(obj, (date, datetime)):
            return obj.isoformat()
        # Defer to the base class so unsupported types raise TypeError instead
        # of falling through, returning None and being written as ``null``.
        return super().default(obj)
2021-03-24 10:24:32 -06:00
# Module-level dict, empty at import time. Nothing in this part of the file
# writes to it -- presumably a registry of marshmallow schemas populated
# elsewhere (TODO confirm against callers).
marshmallow_schemas = {}
def gopath() -> Optional[str]:
    """Return the Go workspace path, or ``None`` when it cannot be determined.

    The ``GOPATH`` environment variable wins when set to a non-empty value.
    Otherwise, if a ``go`` executable is on the PATH, the output of ``go env``
    is scanned for a line starting with ``GOPATH=``.
    """
    from_env = os.getenv("GOPATH")
    if from_env:
        return from_env

    go_exe = shutil.which("go")
    if not go_exe:
        return None

    prefix = "GOPATH="
    env_lines = subprocess.check_output([go_exe, "env"], encoding="utf-8").splitlines()
    for entry in env_lines:
        if entry.startswith(prefix):
            # the value is stripped of surrounding double quotes
            return entry[len(prefix):].strip('"')
    return None
2021-03-24 10:24:32 -06:00
def dict_hash(obj: dict) -> str:
    """Return a deterministic SHA-256 hex digest for a JSON-serializable dict.

    The dict is dumped to JSON with sorted keys, UTF-8 encoded and base64
    encoded; the digest is computed over the base64 bytes.
    """
    serialized = json.dumps(obj, sort_keys=True)
    encoded = base64.b64encode(serialized.encode('utf-8'))
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
2020-06-29 23:17:38 -06:00
def get_json_iter(f):
    """Load the contents of an open JSON or JSON-lines file.

    The first two characters are peeked to choose a parser: content starting
    with ``[`` or with ``{`` followed by a newline is parsed as one JSON
    document; anything else is treated as JSON lines and returned as a list.

    Args:
        f: a seekable, readable file object.

    Returns:
        The parsed document, or a list of the parsed lines. An empty file
        yields an empty list.
    """
    first = f.read(2)
    f.seek(0)

    if not first:
        # Nothing to parse; indexing first[0] below would raise IndexError.
        return []

    if first[0] == '[' or first == "{\n":
        return json.load(f)

    return list(stream_json_lines(f))
def get_path(*paths) -> Path:
    """Return the path formed by joining *paths* onto ``ROOT_DIR``."""
    return Path(ROOT_DIR, *paths)
2020-06-29 23:17:38 -06:00
def get_etc_path(*paths) -> Path:
    """Return the path formed by joining *paths* onto the detection_rules/etc/ folder."""
    return Path(ETC_DIR, *paths)
2020-06-29 23:17:38 -06:00
def get_etc_glob_path(*patterns) -> list:
    """Return the paths under the detection_rules/etc/ folder that match a glob pattern.

    The *patterns* segments are joined with ``os.path.join`` into a single
    pattern relative to the etc folder; matches are returned as strings.
    """
    relative_pattern = os.path.join(*patterns)
    full_pattern = ETC_DIR / relative_pattern
    return glob.glob(str(full_pattern))
2020-06-29 23:17:38 -06:00
def get_etc_file(name, mode="r"):
    """Read and return the full contents of a file in the detection_rules/etc/ folder.

    *mode* is passed to ``open``; use ``"rb"`` to get bytes instead of text.
    """
    etc_file = get_etc_path(name)
    with etc_file.open(mode) as handle:
        data = handle.read()
    return data
def load_etc_dump(*path):
    """Parse a json/yml/toml file located under the detection_rules/etc/ folder."""
    target = get_etc_path(*path)
    return eql.utils.load_dump(str(target))
2020-06-29 23:17:38 -06:00
def save_etc_dump(contents, *path, **kwargs):
    """Write *contents* to a json/yml/toml file under the detection_rules/etc/ folder.

    For ``.json`` targets the data is written with ``json.dump`` using
    ``DateTimeEncoder``; ``sort_keys`` (default True), ``indent`` (default 2)
    and any remaining keyword arguments are forwarded to it. Any other
    extension is delegated to ``eql.utils.save_dump`` and the keyword
    arguments are not forwarded.
    """
    target = str(get_etc_path(*path))
    suffix = os.path.splitext(target)[1]
    sort_keys = kwargs.pop('sort_keys', True)
    indent = kwargs.pop('indent', 2)

    if suffix != ".json":
        return eql.utils.save_dump(contents, target)

    with open(target, "wt") as handle:
        json.dump(contents, handle, cls=DateTimeEncoder, sort_keys=sort_keys, indent=indent, **kwargs)
2020-06-29 23:17:38 -06:00
def gzip_compress(contents) -> bytes:
    """Gzip *contents* in memory and return the compressed bytes.

    Non-bytes input is encoded as UTF-8 before compression.
    """
    payload = contents if isinstance(contents, bytes) else contents.encode("utf8")
    buffer = io.BytesIO()
    with gzip.GzipFile(mode="w", fileobj=buffer) as writer:
        writer.write(payload)
    return buffer.getvalue()
def read_gzip(path):
    """Decompress the gzip file at *path* and return its contents decoded as UTF-8 text."""
    with gzip.GzipFile(path, mode='r') as archive:
        raw = archive.read()
    return raw.decode("utf8")
2020-06-29 23:17:38 -06:00
@contextlib.contextmanager
def unzip(contents):  # type: (bytes) -> zipfile.ZipFile
    """Context manager yielding a read-only ``ZipFile`` over raw zipped bytes.

    The archive is closed when the ``with`` block exits, including on error.
    """
    with zipfile.ZipFile(io.BytesIO(contents), mode="r") as archive:
        yield archive
def unzip_and_save(contents, path, member=None, verbose=True):
    """Extract raw zipped contents to a directory.

    Args:
        contents: the zip archive as bytes.
        path: destination directory.
        member: optional single member (name or ``ZipInfo``) to extract; when
            omitted the whole archive is extracted.
        verbose: when True, print the names of the files that were saved.
    """
    with unzip(contents) as archive:
        if member:
            archive.extract(member, path)
            # extract() accepts a name or a ZipInfo; report the name either way.
            name_list = [getattr(member, 'filename', member)]
        else:
            archive.extractall(path)
            name_list = archive.namelist()

    if verbose:
        print('Saved files to {}: \n\t- {}'.format(path, '\n\t- '.join(name_list)))
def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[dict, str]]:
    """Read every file in an open archive into a dict keyed by base filename.

    Directory entries (names ending in ``/``) are skipped. When *load_json* is
    true, members with a ``.json`` suffix are parsed with ``json.loads``; all
    other members are stored as returned by ``ZipFile.read``. Keys are the final
    path component only, so same-named files in different folders overwrite
    each other.
    """
    bundle = {}
    file_names = (entry for entry in zipped.namelist() if not entry.endswith('/'))

    for entry in file_names:
        entry_path = Path(entry)
        raw = zipped.read(entry)
        parse_as_json = load_json and entry_path.suffix == '.json'
        bundle[entry_path.name] = json.loads(raw) if parse_as_json else raw

    return bundle
2020-06-29 23:17:38 -06:00
def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True):
    """Sort events from elasticsearch by timestamp.

    Args:
        events: iterable of event dicts.
        timestamp: key holding each event's timestamp string.
        date_format: ``strptime`` format used to parse the timestamp.
        asc: sort ascending when True, descending otherwise.

    Returns:
        A new list of the events ordered by parsed timestamp.
    """

    def round_microseconds(t: str) -> str:
        """Rounds the microseconds part of a timestamp string to 6 decimal places."""
        if not t:
            # Return early if the timestamp string is empty
            return t

        parts = t.split('.')
        if len(parts) == 2:
            # Remove trailing "Z" from microseconds part
            micro_seconds = parts[1].rstrip("Z")

            if len(micro_seconds) > 6:
                # %f accepts at most 6 digits, so round the fraction to 6 places.
                rounded_micro_seconds = round(float(f"0.{micro_seconds}"), 6)
                # A fraction such as .9999999 rounds up to 1.0; only the digits
                # after the point are kept, so the carry would be lost and the
                # timestamp would jump back by a whole second. Clamp instead.
                rounded_micro_seconds = min(rounded_micro_seconds, 0.999999)

                # Format the value to always have 6 decimal places and
                # reconstruct the timestamp string around it.
                formatted_micro_seconds = f'{rounded_micro_seconds:0.6f}'.split(".")[-1]
                t = f"{parts[0]}.{formatted_micro_seconds}Z"

        return t

    def _event_sort(event: dict) -> datetime:
        """Calculates the sort key for an event as a datetime object."""
        t = round_microseconds(event[timestamp])
        # Return the timestamp as a datetime object for comparison
        return datetime.strptime(t, date_format)

    return sorted(events, key=_event_sort, reverse=not asc)
def combine_sources(*sources):  # type: (list[list]) -> list
    """Merge several event lists into a single list sorted by ``event_sort`` defaults.

    Each source is shallow-copied before being merged, so the input lists are
    not modified.
    """
    merged = [event for source in sources for event in source.copy()]
    return event_sort(merged)
def convert_time_span(span: str) -> int:
    """Convert a Date Math time span to milliseconds.

    All digit characters in *span* form the amount and all alphabetic
    characters form the unit; both are handed to ``eql.ast`` for conversion.
    """
    digits = "".join(filter(str.isdigit, span))
    amount = int(digits)
    letters = "".join(filter(str.isalpha, span))
    unit = eql.ast.TimeUnit(letters)
    time_range = eql.ast.TimeRange(amount, unit)
    return time_range.as_milliseconds()
2020-06-29 23:17:38 -06:00
def evaluate(rule, events):
    """Return the list of events that match the rule's query.

    ``rule.query`` is parsed with ``kql.parse`` and turned into a predicate by
    ``kql.get_evaluator``; events for which the predicate is truthy are kept.
    """
    matches = kql.get_evaluator(kql.parse(rule.query))
    return [event for event in events if matches(event)]
def unix_time_to_formatted(timestamp):  # type: (int|str) -> str
    """Converts unix time in seconds or milliseconds to the default format.

    Args:
        timestamp: unix time as a number or numeric string. Values greater than
            2 ** 32 are treated as milliseconds, anything else as seconds.

    Returns:
        UTC time formatted as ``YYYY-mm-ddTHH:MM:SS.fffZ`` (millisecond precision).
    """
    # Local import: `timezone` is not part of this module's top-level imports.
    from datetime import timezone

    if isinstance(timestamp, str):
        # The signature advertises str input; convert numeric strings up front.
        timestamp = float(timestamp)

    if isinstance(timestamp, (int, float)):
        if timestamp > 2 ** 32:
            timestamp = round(timestamp / 1000, 3)

    # datetime.utcfromtimestamp() is deprecated since Python 3.12; an aware UTC
    # conversion yields the same formatted string.
    as_utc = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return as_utc.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'
def normalize_timing_and_sort(events, timestamp='@timestamp', asc=True):
    """Convert non-string timestamps to the default string format, then sort.

    Events are updated in place: any *timestamp* value that is not already a
    string is replaced by ``unix_time_to_formatted(value)``. The sorted list
    from ``event_sort`` is returned.
    """
    for event in events:
        value = event[timestamp]
        if isinstance(value, str):
            continue
        event[timestamp] = unix_time_to_formatted(value)

    return event_sort(events, timestamp=timestamp, asc=asc)
def freeze(obj):
    """Recursively convert a mutable structure into a hashable, tuple-based equivalent.

    Dataclass instances become tuples of their fields, lists and tuples become
    tuples of frozen items, and dicts become a frozen tuple of their sorted
    ``(key, value)`` pairs. Anything else is returned unchanged.
    """
    # dataclass classes (as opposed to instances) are left alone
    if is_dataclass(obj) and not isinstance(obj, type):
        obj = astuple(obj)

    if isinstance(obj, dict):
        return freeze(sorted(obj.items()))

    if isinstance(obj, (list, tuple)):
        return tuple(freeze(item) for item in obj)

    return obj
# Memoization store used by `cached`: id(function) -> {(frozen args, frozen kwargs) -> result}.
_cache = {}
def cached(f):
    """Decorator that memoizes *f* in the module-level ``_cache``.

    Results are stored per function (keyed by ``id(f)``) and per call, using
    the frozen positional and keyword arguments as the key. The returned
    wrapper has a ``clear()`` attribute that drops this function's entries.
    """
    func_key = id(f)

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        _cache.setdefault(func_key, {})
        call_key = (freeze(args), freeze(kwargs))

        if call_key in _cache[func_key]:
            return _cache[func_key][call_key]

        _cache[func_key][call_key] = f(*args, **kwargs)
        return _cache[func_key][call_key]

    def clear():
        _cache.pop(func_key, None)

    wrapped.clear = clear
    return wrapped
def clear_caches():
    """Drop every memoized result stored by ``cached``, for all functions."""
    _cache.clear()
def rulename_to_filename(name: str, tactic_name: str = None, ext: str = '.toml') -> str:
    """Convert a rule name to a filename.

    The name is stripped and lowercased, every run of characters outside
    ``[_a-z0-9]`` is replaced by a single underscore, and leading/trailing
    underscores are removed.

    Args:
        name: the rule name.
        tactic_name: optional tactic name; normalized the same way and
            prepended as ``<tactic>_``.
        ext: extension to append; ``None`` or ``''`` appends nothing.
    """
    name = re.sub(r'[^_a-z0-9]+', '_', name.strip().lower()).strip('_')
    if tactic_name:
        pre = rulename_to_filename(name=tactic_name, ext='')
        name = f'{pre}_{name}'
    # Parenthesized on purpose: `name + ext or ''` parses as `(name + ext) or ''`
    # and raises TypeError when ext is None.
    return name + (ext or '')
2021-09-10 10:06:04 -08:00
def load_rule_contents(rule_file: Path, single_only=False) -> list:
    """Load a rule file from multiple formats.

    Supported extensions (case-insensitive): ``.ndjson``/``.jsonl``, ``.toml``
    and ``.yaml``/``.yml``. Any other extension returns an empty list.

    Args:
        rule_file: path to the rule file.
        single_only: for ndjson input, raise if more than one rule is present.

    Returns:
        A list of rule dicts. Empty ndjson input yields ``[{}]``.

    Raises:
        ValueError: if *single_only* is set and several ndjson rules are found,
            or if a toml/yaml file does not contain a dict or a list.
    """
    # Path.suffix includes the leading dot, so every comparison below needs it.
    extension = rule_file.suffix.lower()
    raw_text = rule_file.read_text()

    if extension in ('.ndjson', '.jsonl'):
        # kibana exported rule object is ndjson with the export metadata on the last line
        # (blank lines are skipped; json.loads would fail on them)
        contents = [json.loads(line) for line in raw_text.splitlines() if line.strip()]

        if len(contents) > 1 and 'exported_count' in contents[-1]:
            contents.pop(-1)

        if single_only and len(contents) > 1:
            raise ValueError('Multiple rules not allowed')

        return contents or [{}]
    elif extension == '.toml':
        rule = pytoml.loads(raw_text)
    elif extension in ('.yaml', '.yml'):
        rule = load_dump(str(rule_file))
    else:
        return []

    if isinstance(rule, dict):
        return [rule]
    elif isinstance(rule, list):
        return rule
    else:
        raise ValueError(f"Expected a list or dictionary in {rule_file}")
def format_command_options(ctx):
    """Render the options of a click command as formatted help text.

    Every parameter of ``ctx.command`` except the one named ``help`` is asked
    for its help record; the non-empty records are written as a definition
    list under an ``Options`` heading. Returns the formatter's text.
    """
    formatter = ctx.make_formatter()

    records = (
        param.get_help_record(ctx)
        for param in ctx.command.get_params(ctx)
        if param.name != 'help'
    )
    opts = [record for record in records if record is not None]

    if opts:
        with formatter.section('Options'):
            formatter.write_dl(opts)

    return formatter.getvalue()
def make_git(*prefix_args) -> Optional[Callable]:
    """Build a callable that runs git with *prefix_args* placed before each call's arguments.

    When no git executable is found, an error is echoed to stderr, the current
    click context (if any) is exited with status 1, and ``None`` is returned.

    The returned callable prepends ``-C <get_path()>`` to the prefix arguments
    on first use unless ``-C`` is already among them. With ``print_output=True``
    it returns the result of ``subprocess.check_call``; otherwise it returns the
    captured output with trailing whitespace stripped.
    """
    git_exe = shutil.which("git")
    prefix_args = [str(arg) for arg in prefix_args]

    if not git_exe:
        click.secho("Unable to find git", err=True, fg="red")
        ctx = click.get_current_context(silent=True)

        if ctx is not None:
            ctx.exit(1)

        return None

    def git(*args, print_output=False):
        nonlocal prefix_args

        if '-C' not in prefix_args:
            prefix_args = ['-C', get_path(), *prefix_args]

        command = [git_exe, *prefix_args, *(str(arg) for arg in args)]

        if print_output:
            return subprocess.check_call(command)
        captured = subprocess.check_output(command, encoding="utf-8")
        return captured.rstrip()

    return git
2021-08-05 01:15:39 -06:00
def git(*args, **kwargs):
    """Run a single git command through a freshly built ``make_git()`` runner.

    All arguments are forwarded to that runner unchanged.
    """
    runner = make_git()
    return runner(*args, **kwargs)
def add_params(*params):
    """Return a decorator that appends *params* to a function's ``__click_params__`` list.

    The list is created on the function when it does not exist yet; the
    function itself is returned unwrapped.
    """

    def decorator(f):
        try:
            registered = f.__click_params__
        except AttributeError:
            registered = f.__click_params__ = []

        registered.extend(params)
        return f

    return decorator
class Ndjson(list):
    """Wrapper for ndjson data: a list whose items are serialized one JSON document per line."""

    def to_string(self, sort_keys: bool = False):
        """Format contents list to ndjson string.

        Each item is dumped on its own line and the result ends with a newline.
        """
        return '\n'.join(json.dumps(c, sort_keys=sort_keys) for c in self) + '\n'

    @classmethod
    def from_string(cls, ndjson_string: str, **kwargs):
        """Load ndjson string to a list.

        Blank lines are ignored. Extra keyword arguments are passed to
        ``json.loads``. The result is an instance of the class the method was
        called on, so subclasses keep their type.
        """
        contents = [json.loads(line, **kwargs) for line in ndjson_string.strip().splitlines() if line.strip()]
        return cls(contents)

    def dump(self, filename: Path, sort_keys=False):
        """Save contents to an ndjson file."""
        filename.write_text(self.to_string(sort_keys=sort_keys))

    @classmethod
    def load(cls, filename: Path, **kwargs):
        """Load content from an ndjson file."""
        return cls.from_string(filename.read_text(), **kwargs)
2023-03-28 07:17:50 -06:00
class PatchedTemplate(Template):
    """String template with updated methods from future versions."""

    def get_identifiers(self):
        """Return the template's valid identifiers in order of first appearance.

        Invalid placeholders and escaped delimiters are ignored. Raises
        ``ValueError`` when a match sets none of the named, braced, invalid or
        escaped groups, i.e. the pattern contains an unexpected group.
        """
        # https://github.com/python/cpython/blob/3b4f8fc83dcea1a9d0bc5bd33592e5a3da41fa71/Lib/string.py#LL157-L171C19
        seen = []
        for match in self.pattern.finditer(self.template):
            identifier = match.group('named') or match.group('braced')

            if identifier is None:
                if match.group('invalid') is None and match.group('escaped') is None:
                    # none of the known groups matched, so the pattern must
                    # define a group this method does not know about
                    raise ValueError('Unrecognized named group in pattern',
                                     self.pattern)
                continue

            if identifier not in seen:
                # record each identifier only the first time it appears
                seen.append(identifier)
        return seen