Files
sigma-rules/detection_rules/rule.py
T
Ross Wolf 8ee1b2ffd4 Fix the version lock update code (#1064)
* Fix the version lock update code
* Add Rule.lock_info() method
2021-03-25 14:48:31 -06:00

453 lines
15 KiB
Python

# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Rule object."""
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Union, Optional, List, Any
from uuid import uuid4
import eql
from marshmallow import validates_schema
import kql
from . import ecs, beats, utils
from .mixins import MarshmallowDataclassMixin
from .rule_formatter import toml_write, nested_normalize
from .schemas import downgrade
from .schemas import definitions
from .utils import get_path, cached
# Path returned by `get_path("rules")` -- presumably the repository's rules directory.
RULES_DIR = get_path("rules")
# NOTE(review): not referenced anywhere in the visible part of this module; confirm it is still
# needed before relying on it.
_META_SCHEMA_REQ_DEFAULTS = {}
@dataclass(frozen=True)
class RuleMeta(MarshmallowDataclassMixin):
    """Data stored in a rule's [metadata] section of TOML.

    The declaration order of the fields is significant: it is the positional order of the
    ``__init__`` generated by ``@dataclass``.
    """

    # Date fields
    creation_date: definitions.Date
    updated_date: definitions.Date
    deprecation_date: Optional[definitions.Date]

    # Optional fields
    # beats/ecs versions drive query validation; when unset, TOMLRuleContents.validate_query
    # falls back to the maximum known version of each.
    beats_version: Optional[definitions.SemVer]
    ecs_versions: Optional[List[definitions.SemVer]]
    comments: Optional[str]
    # a maturity of "deprecated" limits query validation to a syntax check
    maturity: Optional[definitions.Maturity]
    os_type_list: Optional[List[definitions.OSType]]
    # an explicit False limits query validation to a syntax check
    query_schema_validation: Optional[bool]
    related_endpoint_rules: Optional[List[str]]
@dataclass(frozen=True)
class BaseThreatEntry:
    """Fields shared by the tactic, technique and subtechnique entries of a threat mapping."""

    id: str
    name: str
    # subclasses narrow this to the URL type matching the kind of entry
    reference: str
@dataclass(frozen=True)
class SubTechnique(BaseThreatEntry):
    """Mapping to threat subtechnique."""

    # narrows BaseThreatEntry.reference to the subtechnique URL type
    reference: definitions.SubTechniqueURL
@dataclass(frozen=True)
class Technique(BaseThreatEntry):
    """Mapping to threat technique."""

    # subtechniques are stored at threat[].technique.subtechnique[]
    # narrows BaseThreatEntry.reference to the technique URL type
    reference: definitions.TechniqueURL
    subtechnique: Optional[List[SubTechnique]]
@dataclass(frozen=True)
class Tactic(BaseThreatEntry):
    """Mapping to a threat tactic."""

    # narrows BaseThreatEntry.reference to the tactic URL type
    reference: definitions.TacticURL
@dataclass(frozen=True)
class ThreatMapping(MarshmallowDataclassMixin):
    """Mapping to a threat framework."""

    framework: Literal["MITRE ATT&CK"]
    tactic: Tactic
    technique: Optional[List[Technique]]

    @staticmethod
    def flatten(threat_mappings: Optional[List]) -> 'FlatThreatMapping':
        """Get flat lists of tactic and technique info.

        :param threat_mappings: ``ThreatMapping`` entries to flatten; ``None`` is treated as empty.
        :return: a ``FlatThreatMapping`` whose lists are all sorted. Tactic names/ids keep one item
            per mapping entry (duplicates preserved); technique and subtechnique names/ids are
            de-duplicated.
        """
        tactic_names = []
        tactic_ids = []
        technique_ids = set()
        technique_names = set()
        sub_technique_ids = set()
        sub_technique_names = set()

        for entry in threat_mappings or []:
            tactic_names.append(entry.tactic.name)
            tactic_ids.append(entry.tactic.id)

            for technique in entry.technique or []:
                technique_names.add(technique.name)
                technique_ids.add(technique.id)

                for subtechnique in technique.subtechnique or []:
                    # Must be `add`, not `update`: `set.update` iterates its argument, so passing
                    # a string would insert every single character instead of the whole value.
                    sub_technique_ids.add(subtechnique.id)
                    sub_technique_names.add(subtechnique.name)

        return FlatThreatMapping(
            tactic_names=sorted(tactic_names),
            tactic_ids=sorted(tactic_ids),
            technique_names=sorted(technique_names),
            technique_ids=sorted(technique_ids),
            sub_technique_names=sorted(sub_technique_names),
            sub_technique_ids=sorted(sub_technique_ids),
        )
@dataclass(frozen=True)
class RiskScoreMapping(MarshmallowDataclassMixin):
    """One entry of a rule's ``risk_score_mapping`` list."""

    field: str
    operator: Optional[definitions.Operator]
    value: Optional[str]
@dataclass(frozen=True)
class SeverityMapping(MarshmallowDataclassMixin):
    """One entry of a rule's ``severity_mapping`` list."""

    field: str
    operator: Optional[definitions.Operator]
    value: Optional[str]
    severity: Optional[str]
@dataclass(frozen=True)
class FlatThreatMapping(MarshmallowDataclassMixin):
    """Flattened view of a list of threat mappings, as produced by ``ThreatMapping.flatten``.

    ``ThreatMapping.flatten`` fills every list in sorted order.
    """

    tactic_names: List[str]
    tactic_ids: List[str]
    technique_names: List[str]
    technique_ids: List[str]
    sub_technique_names: List[str]
    sub_technique_ids: List[str]
@dataclass(frozen=True)
class BaseRuleData(MarshmallowDataclassMixin):
    """Fields common to every rule type; base class of the query and machine learning rule data.

    The declaration order of the fields is significant: it is the positional order of the
    ``__init__`` generated by ``@dataclass``.
    """

    actions: Optional[list]
    author: List[str]
    building_block_type: Optional[str]
    description: Optional[str]
    enabled: Optional[bool]
    exceptions_list: Optional[list]
    license: Optional[str]
    false_positives: Optional[List[str]]
    filters: Optional[List[dict]]
    # trailing `_` required since `from` is a reserved word in python;
    # the serialized key is still "from" (see `data_key`)
    from_: Optional[str] = field(metadata=dict(data_key="from"))
    interval: Optional[definitions.Interval]
    max_signals: Optional[definitions.MaxSignals]
    meta: Optional[dict]
    name: str
    note: Optional[definitions.Markdown]
    # NOTE: `output_index` is explicitly NOT allowed, which is why the field below stays
    # commented out.
    # TODO(review): decide whether this reminder can be removed.
    # output_index: Optional[str]
    references: Optional[List[str]]
    risk_score: definitions.RiskScore
    risk_score_mapping: Optional[List[RiskScoreMapping]]
    rule_id: definitions.UUIDString
    rule_name_override: Optional[str]
    severity_mapping: Optional[List[SeverityMapping]]
    severity: definitions.Severity
    tags: Optional[List[str]]
    throttle: Optional[str]
    timeline_id: Optional[str]
    timeline_title: Optional[str]
    timestamp_override: Optional[str]
    to: Optional[str]
    # subclasses narrow this to the literal rule type they represent
    type: Literal[definitions.RuleType]
    threat: Optional[List[ThreatMapping]]
@dataclass(frozen=True)
class BaseQueryRuleData(BaseRuleData):
    """Specific fields for query event types."""

    type: Literal["query"]

    index: Optional[List[str]]
    query: str
    language: str

    @property
    def parsed_query(self) -> Optional[object]:
        """Return the parsed form of ``query``.

        The base implementation has no parser and always returns ``None``; the KQL and EQL
        subclasses override it.
        """
        return None
@dataclass(frozen=True)
class KQLRuleData(BaseQueryRuleData):
    """Query rule whose ``query`` is written in KQL (``language`` is ``"kuery"``)."""

    language: Literal["kuery"]

    @property
    def parsed_query(self) -> kql.ast.Expression:
        """Parse ``query`` with the KQL parser, without applying any schema."""
        return kql.parse(self.query)

    @property
    def unique_fields(self):
        """List the distinct field names referenced by the parsed query (in no particular order)."""
        field_names = {str(node) for node in self.parsed_query if isinstance(node, kql.ast.Field)}
        return list(field_names)

    def to_eql(self) -> eql.ast.Expression:
        """Convert the KQL ``query`` to its EQL expression."""
        return kql.to_eql(self.query)

    def validate_query(self, beats_version: str, ecs_versions: List[str]):
        """Validate the query against ECS/beats schemas.

        Called from the parent object, which holds the [metadata] information.

        :param beats_version: beats version used to build the beats schema.
        :param ecs_versions: ECS versions to validate against; when empty, the query is parsed
            once against the schema ``ecs.get_kql_schema`` returns without an explicit version.
        :raises kql.KqlParseError: if the query does not validate; a hint about beats modules is
            appended when an unknown field is reported for a beats index.
        """
        index_patterns = self.index or []
        parsed_ast = self.parsed_query

        # e.g. "winlogbeat-*" -> "winlogbeat"
        beat_types = [pattern.split("-")[0] for pattern in index_patterns if "beat-*" in pattern]
        beat_schema = None
        if beat_types:
            beat_schema = beats.get_schema_from_kql(parsed_ast, beat_types, version=beats_version)

        if not ecs_versions:
            kql.parse(self.query, schema=ecs.get_kql_schema(indexes=index_patterns, beat_schema=beat_schema))
            return

        for ecs_version in ecs_versions:
            schema = ecs.get_kql_schema(version=ecs_version, indexes=index_patterns, beat_schema=beat_schema)

            try:
                kql.parse(self.query, schema=schema)
            except kql.KqlParseError as exc:
                hint = None
                if "Unknown field" in exc.error_msg and beat_types:
                    hint = "\nTry adding event.module or event.dataset to specify beats module"

                # rebuild the error so the hint is attached; `from None` hides the original context
                raise kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source,
                                        len(exc.caret.lstrip()), trailer=hint) from None
@dataclass(frozen=True)
class LuceneRuleData(BaseQueryRuleData):
    """Query rule whose ``query`` is written in Lucene syntax (``language`` is ``"lucene"``).

    No parser is provided for Lucene here, so the inherited ``parsed_query`` returns ``None``.
    """

    language: Literal["lucene"]
@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
    """Specific fields for machine learning rules (``type`` is ``"machine_learning"``)."""

    type: Literal["machine_learning"]

    anomaly_threshold: int
    machine_learning_job_id: str
@dataclass(frozen=True)
class ThresholdQueryRuleData(BaseQueryRuleData):
    """Specific fields for threshold rules (``type`` is ``"threshold"``), written in KQL or Lucene."""

    @dataclass(frozen=True)
    class ThresholdMapping(MarshmallowDataclassMixin):
        """Contents of the rule's ``threshold`` field."""

        @dataclass(frozen=True)
        class ThresholdCardinality:
            """Contents of the optional ``threshold.cardinality`` field."""

            field: str
            value: definitions.ThresholdValue

        field: List[str]
        value: definitions.ThresholdValue
        cardinality: Optional[ThresholdCardinality]

    type: Literal["threshold"]
    language: Literal["kuery", "lucene"]
    threshold: ThresholdMapping
@dataclass(frozen=True)
class EQLRuleData(BaseQueryRuleData):
    """EQL rules are a special case of query rules."""

    type: Literal["eql"]

    @property
    def parsed_query(self) -> eql.ast.Expression:
        """Parse ``query`` with the EQL parser, without applying any schema.

        Parsing uses Elasticsearch syntax and tolerates functions unknown to the parser.
        """
        with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
            return eql.parse_query(self.query)

    @property
    def unique_fields(self):
        """List the distinct field names referenced by the parsed query (in no particular order)."""
        return list({str(node) for node in self.parsed_query if isinstance(node, eql.ast.Field)})

    def validate_query(self, beats_version: str, ecs_versions: List[str]):
        """Validate an EQL query while checking TOMLRule.

        :param beats_version: beats version used to build the beats schema.
        :param ecs_versions: ECS versions whose schemas the query is validated against.
        :raises eql.EqlTypeMismatchError: re-raised unchanged.
        :raises eql.EqlParseError: any other parse error, rebuilt as the same exception class
            with a hint about beats modules appended when an unknown field is reported for a
            beats index.
        """
        # TODO: remove once py-eql supports ipv6 for cidrmatch
        # Or, unregister the cidrMatch function and replace it with one that doesn't validate against strict IPv4
        # the schema-less parse is the same one `parsed_query` performs, so reuse it
        parsed = self.parsed_query

        indexes = self.index or []
        # e.g. "winlogbeat-*" -> "winlogbeat"
        beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
        beat_schema = beats.get_schema_from_eql(parsed, beat_types, version=beats_version) if beat_types else None

        for version in ecs_versions:
            schema = ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version)

            try:
                # TODO: switch to custom cidrmatch that allows ipv6
                with ecs.KqlSchema2Eql(schema), eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
                    eql.parse_query(self.query)

            except eql.EqlTypeMismatchError:
                # listed first so the generic handler below never rewrites type mismatches
                raise

            except eql.EqlParseError as exc:
                message = exc.error_msg
                trailer = None
                if "Unknown field" in message and beat_types:
                    trailer = "\nTry adding event.module or event.dataset to specify beats module"

                # rebuild the error so the hint is attached; `from None` hides the original context
                raise exc.__class__(message, exc.line, exc.column, exc.source,
                                    len(exc.caret.lstrip()), trailer=trailer) from None
# Union of every concrete rule data class that can appear as a rule's `data`.
AnyRuleData = Union[
    KQLRuleData,
    LuceneRuleData,
    MachineLearningRuleData,
    ThresholdQueryRuleData,
    EQLRuleData,
]
@dataclass(frozen=True)
class TOMLRuleContents(MarshmallowDataclassMixin):
    """Rule object which maps directly to the TOML layout."""

    metadata: RuleMeta
    # stored under the "rule" key when serialized
    data: AnyRuleData = field(metadata=dict(data_key="rule"))

    @property
    def id(self) -> definitions.UUIDString:
        """The rule's ``rule_id``."""
        return self.data.rule_id

    @property
    def name(self) -> str:
        """The rule's name."""
        return self.data.name

    def lock_info(self) -> dict:
        """Build a dict with the rule's name, content hash and (auto-bumped) version."""
        return dict(rule_name=self.name, sha256=self.sha256(), version=self.autobumped_version)

    @property
    def is_dirty(self) -> Optional[bool]:
        """Determine if the rule has changed since its version was locked.

        Returns ``None`` when the rule has no entry in the loaded versions.
        """
        from .packaging import load_versions

        locked_versions = load_versions()
        if self.id not in locked_versions:
            return None

        locked_sha256: str = locked_versions[self.id]['sha256']
        return locked_sha256 != self.sha256()

    @property
    def latest_version(self) -> Optional[int]:
        """Retrieve the latest known version of the rule, or ``None`` if it has no entry."""
        from .packaging import load_versions

        locked_versions = load_versions()
        if self.id not in locked_versions:
            return None

        return locked_versions[self.id]['version']

    @property
    def autobumped_version(self) -> Optional[int]:
        """Retrieve the current version of the rule, accounting for automatic increments.

        A rule without a known version is version 1; a known rule is bumped by one when it is
        dirty and keeps its version otherwise.
        """
        known_version = self.latest_version
        if known_version is None:
            return 1

        if self.is_dirty:
            return known_version + 1
        return known_version

    @validates_schema
    def validate_query(self, value: dict, **kwargs):
        """Validate queries by calling into the validator for the relevant method.

        Only EQL and KQL rule data are checked. The check is limited to syntax when the metadata
        sets ``query_schema_validation`` to ``False`` or the maturity is ``"deprecated"``.
        """
        data: AnyRuleData = value["data"]
        metadata: RuleMeta = value["metadata"]

        # fall back to the newest known versions when the metadata does not pin any
        beats_version = metadata.beats_version or beats.get_max_version()
        ecs_versions = metadata.ecs_versions or [ecs.get_max_version()]

        if not isinstance(data, (EQLRuleData, KQLRuleData)):
            return

        syntax_only = metadata.query_schema_validation is False or metadata.maturity == "deprecated"
        if syntax_only:
            # accessing the property parses the query, which raises on bad syntax
            _ = data.parsed_query
        else:
            data.validate_query(beats_version=beats_version, ecs_versions=ecs_versions)

    def to_dict(self, strip_none_values=True) -> dict:
        """Convert to a dict via the parent ``to_dict`` and normalize the result."""
        raw = super().to_dict(strip_none_values=strip_none_values)
        return nested_normalize(raw)

    def flattened_dict(self) -> dict:
        """Merge the rule data and metadata dicts into one; metadata wins on key clashes."""
        return {**self.data.to_dict(), **self.metadata.to_dict()}

    @staticmethod
    def _post_dict_transform(obj: dict) -> dict:
        """Transform the converted API in place before sending to Kibana."""
        # cleanup the whitespace in the rule
        is_eql = obj.get("language") == "eql"
        obj = nested_normalize(obj, eql_rule=is_eql)

        # fill in threat.technique so it's never missing
        for threat in obj.get("threat", []):
            threat.setdefault("technique", [])

        return obj

    def to_api_format(self, include_version=True) -> dict:
        """Convert the TOML rule to the API format.

        :param include_version: when true, add the auto-bumped version under ``"version"``.
        """
        api_dict = self.data.to_dict()
        if include_version:
            api_dict["version"] = self.autobumped_version

        return self._post_dict_transform(api_dict)

    @cached
    def sha256(self) -> str:
        """Hash the API-format contents of the rule."""
        # get the hash of the API dict with the version not included, otherwise it'll always be dirty.
        return utils.dict_hash(self.to_api_format(include_version=False))
@dataclass
class TOMLRule:
    """A rule's contents together with the path of the file it is stored in."""

    contents: TOMLRuleContents = field(hash=True)
    path: Path
    # excluded from hashing, comparison and repr; `repr` takes a bool, so say False explicitly
    # instead of relying on None being falsy
    gh_pr: Any = field(hash=False, compare=False, default=None, repr=False)

    @property
    def id(self):
        """The rule's id, taken from its contents."""
        return self.contents.id

    @property
    def name(self):
        """The rule's name, taken from its contents."""
        return self.contents.name

    def save_toml(self):
        """Write the rule contents as TOML to ``self.path``."""
        converted = self.contents.to_dict()
        toml_write(converted, str(self.path.absolute()))

    def save_json(self, path: Path, include_version: bool = True):
        """Write the rule in API format as JSON to ``path``.

        :param path: destination file; overwritten if it exists.
        :param include_version: forwarded to ``TOMLRuleContents.to_api_format``.
        """
        # Build the payload before opening the file: opening with 'w' truncates it, so a failure
        # during conversion would otherwise leave an empty file behind.
        api_contents = self.contents.to_api_format(include_version=include_version)

        with open(str(path.absolute()), 'w', newline='\n') as f:
            json.dump(api_contents, f, sort_keys=True, indent=2)
            f.write('\n')
def downgrade_contents_from_rule(rule: TOMLRule, target_version: str) -> dict:
    """Generate the downgraded contents from a rule.

    The rule's original id and metadata are recorded under ``meta["original"]`` and the payload
    gets a freshly generated ``rule_id`` before being passed to ``downgrade``.

    :param rule: rule whose API-format contents are downgraded.
    :param target_version: version passed through to ``downgrade``.
    :return: the value returned by ``downgrade``.
    """
    api_payload = rule.contents.to_api_format()

    original_info = dict(id=rule.id, **rule.contents.metadata.to_dict())
    api_payload.setdefault("meta", {})["original"] = original_info

    # the downgraded copy gets its own random id
    api_payload["rule_id"] = str(uuid4())
    return downgrade(api_payload, target_version)