151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
import json
|
|
import re
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Pattern
|
|
|
|
from atomic_red_team.common import base_path
|
|
|
|
|
|
# Platform filter accepted by the platform-matching code in this module:
# either a regular-expression source string or an already-compiled pattern.
PlatformFilter = str | Pattern[str]
|
|
|
|
|
|
class Attack:
    """Loads MITRE ATT&CK technique metadata used by documentation generation.

    The STIX bundle is read lazily on first access to :attr:`techniques` and
    the parsed result is cached on the instance.
    """

    def __init__(self, attack_file: str | Path | None = None):
        """Record which STIX bundle to read; nothing is loaded yet.

        Args:
            attack_file: Path to an ATT&CK STIX JSON bundle. When omitted (or
                falsy), ``<base_path>/atomic_red_team/enterprise-attack.json``
                is used.
        """
        self.attack_file = (
            Path(attack_file)
            if attack_file
            else Path(base_path) / "atomic_red_team" / "enterprise-attack.json"
        )
        # Lazily filled caches backing the `techniques` / `techniques_by_id`
        # properties.
        self._techniques: list[dict] | None = None
        self._techniques_by_id: dict[str, dict] | None = None

    def ordered_tactics(self) -> list[str]:
        """Return the tactic phase names in the column order of the matrix."""
        return [
            "reconnaissance",
            "resource-development",
            "initial-access",
            "execution",
            "persistence",
            "privilege-escalation",
            "stealth",
            "defense-impairment",
            "credential-access",
            "discovery",
            "lateral-movement",
            "collection",
            "command-and-control",
            "exfiltration",
            "impact",
        ]

    def technique_identifier_for_technique(self, technique: dict) -> str:
        """Return the upper-cased ATT&CK ID of *technique*.

        The ID is the ``external_id`` of the first ``external_references``
        entry whose ``source_name`` is ``"mitre-attack"``.

        Raises:
            ValueError: If *technique* has no ``"mitre-attack"`` reference.
        """
        reference = next(
            (
                ref
                for ref in technique.get("external_references", [])
                if ref.get("source_name") == "mitre-attack"
            ),
            None,
        )
        if reference is None:
            # A bare next() would leak StopIteration, which is silently
            # swallowed by iteration machinery or turned into RuntimeError
            # inside generators; raise something descriptive instead.
            raise ValueError("technique has no 'mitre-attack' external reference")
        return reference["external_id"].upper()

    def technique_info(self, technique_id: str) -> dict | None:
        """Return the technique dict for *technique_id* (case-insensitive), or None."""
        return self.techniques_by_id.get(technique_id.upper())

    def ordered_tactic_to_technique_matrix(
        self,
        only_platform: PlatformFilter = ".*",
    ) -> list[list[dict | None]]:
        """Lay the techniques out as rows of a tactic-by-technique matrix.

        Args:
            only_platform: Platform filter forwarded to
                :meth:`techniques_by_tactic`.

        Returns:
            A list of rows. Each row holds one cell per tactic, in
            :meth:`ordered_tactics` order; tactics with fewer techniques than
            the tallest column are padded with ``None``. An empty list is
            returned when no technique passes the filter.
        """
        by_tactic = self.techniques_by_tactic(only_platform=only_platform)
        columns = [by_tactic.get(tactic, []) for tactic in self.ordered_tactics()]
        height = max((len(column) for column in columns), default=0)
        # Pad into new lists so the grouped lists are left untouched; with
        # height == 0 every column is empty and zip() yields no rows.
        padded = [column + [None] * (height - len(column)) for column in columns]
        return [list(row) for row in zip(*padded)]

    def techniques_by_tactic(
        self,
        only_platform: PlatformFilter = ".*",
    ) -> dict[str, list[dict]]:
        """Group techniques by the phase name of their ATT&CK tactics.

        A technique is skipped when it lists no ``x_mitre_platforms``, when
        none of its platforms matches *only_platform*, or when it is flagged
        ``revoked`` or ``x_mitre_deprecated``. A technique is listed under
        every ``"mitre-attack"`` kill-chain phase it declares.

        Args:
            only_platform: Regex string or compiled pattern selecting platforms.

        Returns:
            A ``defaultdict(list)`` mapping phase name to technique dicts.
        """
        result: dict[str, list[dict]] = defaultdict(list)
        for technique in self.techniques:
            platforms = technique.get("x_mitre_platforms")
            if not platforms:
                continue
            if not any(
                _matches_platform(platform, only_platform) for platform in platforms
            ):
                continue
            if technique.get("revoked", False) or technique.get(
                "x_mitre_deprecated", False
            ):
                continue
            for tactic in technique.get("kill_chain_phases", []):
                if tactic.get("kill_chain_name") == "mitre-attack":
                    result[tactic["phase_name"]].append(technique)
        return result

    @property
    def techniques(self) -> list[dict]:
        """All ``attack-pattern`` objects of the bundle that carry an ATT&CK ID.

        Each entry is a shallow copy of the STIX object. For an ID containing
        a dot, the name is rewritten to ``"<parent name>: <own name>"`` when
        an object with the ID before the dot exists. The file is read and
        parsed once; later accesses return the cached list.
        """
        if self._techniques is None:
            # Decode explicitly as UTF-8: relying on the locale default
            # breaks on platforms whose default codec is not UTF-8.
            bundle = json.loads(self.attack_file.read_text(encoding="utf-8"))
            # Resolve each object's ID once and drop objects without one.
            identified = [
                (tid, obj)
                for obj in bundle["objects"]
                if obj.get("type") == "attack-pattern"
                and (tid := _attack_id(obj)) is not None
            ]
            id_to_name = {tid: obj["name"] for tid, obj in identified}
            techniques = []
            for tid, obj in identified:
                t = dict(obj)
                if "." in tid:
                    parent_name = id_to_name.get(tid.split(".")[0])
                    if parent_name:
                        t["name"] = f"{parent_name}: {t['name']}"
                techniques.append(t)
            self._techniques = techniques
        return self._techniques

    @property
    def techniques_by_id(self) -> dict[str, dict]:
        """Mapping of upper-cased ATT&CK ID to technique dict (built once, cached)."""
        if self._techniques_by_id is None:
            self._techniques_by_id = {
                tid: t
                for t in self.techniques
                if (tid := _attack_id(t)) is not None
            }
        return self._techniques_by_id
|
|
|
|
|
|
def _attack_id(technique: dict) -> str | None:
|
|
"""Return the ATT&CK ID (e.g. 'T1059.001') for a technique dict, or None."""
|
|
for ref in technique.get("external_references", []):
|
|
if ref.get("source_name") == "mitre-attack":
|
|
return ref["external_id"].upper()
|
|
return None
|
|
|
|
|
|
def _matches_platform(platform: str, only_platform: PlatformFilter) -> bool:
    """Tell whether a STIX platform name passes the platform filter.

    Args:
        platform: Platform name as spelled in the STIX data.
        only_platform: Regex string or compiled pattern to match against.

    Returns:
        True when the filter matches the normalized platform name. For a
        string filter, True is also returned when the normalized name occurs
        literally inside the filter string.
    """
    # STIX spells platforms like "Azure AD" or "IaaS"; filters use the
    # lower-case, hyphenated YAML spelling ("azure-ad", "iaas").
    candidate = platform.lower().replace(" ", "-")

    if not isinstance(only_platform, str):
        return only_platform.search(candidate) is not None

    if re.search(only_platform, candidate) is not None:
        return True
    # Reverse direction, so platform "iaas" passes a filter such as "iaas:gcp".
    return candidate in only_platform
|