Files
atomic-red-team/atomic_red_team/attack_api.py
T
2026-05-02 18:30:22 -04:00

151 lines
5.2 KiB
Python

import json
import re
from collections import defaultdict
from itertools import zip_longest
from pathlib import Path
from typing import Pattern

from atomic_red_team.common import base_path
PlatformFilter = str | Pattern[str]
class Attack:
"""Loads MITRE ATT&CK technique metadata used by documentation generation."""
def __init__(self, attack_file: str | Path | None = None):
self.attack_file = (
Path(attack_file)
if attack_file
else Path(base_path) / "atomic_red_team" / "enterprise-attack.json"
)
self._techniques: list[dict] | None = None
self._techniques_by_id: dict[str, dict] | None = None
def ordered_tactics(self) -> list[str]:
return [
"reconnaissance",
"resource-development",
"initial-access",
"execution",
"persistence",
"privilege-escalation",
"stealth",
"defense-impairment",
"credential-access",
"discovery",
"lateral-movement",
"collection",
"command-and-control",
"exfiltration",
"impact",
]
def technique_identifier_for_technique(self, technique: dict) -> str:
reference = next(
ref
for ref in technique.get("external_references", [])
if ref.get("source_name") == "mitre-attack"
)
return reference["external_id"].upper()
def technique_info(self, technique_id: str) -> dict | None:
return self.techniques_by_id.get(technique_id.upper())
def ordered_tactic_to_technique_matrix(
self,
only_platform: PlatformFilter = ".*",
) -> list[list[dict | None]]:
techniques_by_tactic = self.techniques_by_tactic(only_platform=only_platform)
tactic_order = [
techniques_by_tactic[tactic] for tactic in self.ordered_tactics()
]
max_techniques = max(
(len(techniques) for techniques in tactic_order), default=0
)
if max_techniques == 0:
return []
for techniques in tactic_order:
techniques.extend([None] * (max_techniques - len(techniques)))
return [list(row) for row in zip(*tactic_order)]
def techniques_by_tactic(
self,
only_platform: PlatformFilter = ".*",
) -> dict[str, list[dict]]:
result: dict[str, list[dict]] = defaultdict(list)
for technique in self.techniques:
platforms = technique.get("x_mitre_platforms")
if not platforms:
continue
if not any(
_matches_platform(platform, only_platform) for platform in platforms
):
continue
if technique.get("revoked", False) or technique.get(
"x_mitre_deprecated", False
):
continue
for tactic in technique.get("kill_chain_phases", []):
if tactic.get("kill_chain_name") == "mitre-attack":
result[tactic["phase_name"]].append(technique)
return result
@property
def techniques(self) -> list[dict]:
if self._techniques is None:
raw = [
obj
for obj in json.loads(self.attack_file.read_text())["objects"]
if obj.get("type") == "attack-pattern"
]
id_to_name = {
_attack_id(obj): obj["name"]
for obj in raw
if _attack_id(obj) is not None
}
techniques = []
for obj in raw:
tid = _attack_id(obj)
if tid is None:
continue
t = dict(obj)
if "." in tid:
parent_name = id_to_name.get(tid.split(".")[0])
if parent_name:
t["name"] = f"{parent_name}: {t['name']}"
techniques.append(t)
self._techniques = techniques
return self._techniques
@property
def techniques_by_id(self) -> dict[str, dict]:
if self._techniques_by_id is None:
self._techniques_by_id = {
tid: t
for t in self.techniques
if (tid := _attack_id(t)) is not None
}
return self._techniques_by_id
def _attack_id(technique: dict) -> str | None:
"""Return the ATT&CK ID (e.g. 'T1059.001') for a technique dict, or None."""
for ref in technique.get("external_references", []):
if ref.get("source_name") == "mitre-attack":
return ref["external_id"].upper()
return None
def _matches_platform(platform: str, only_platform: PlatformFilter) -> bool:
    """Decide whether an ATT&CK platform name satisfies a platform filter.

    The STIX name is lower-cased and spaces become hyphens ("Azure AD" ->
    "azure-ad", "IaaS" -> "iaas") so it lines up with the hyphenated YAML
    platform strings used as filters.

    A compiled pattern is only searched in the normalized name. A string
    filter is first used as a regex searched in the normalized name. Failing
    that, the normalized name is searched for as a literal inside the filter
    text, so platform "iaas" also satisfies the filter "iaas:gcp".
    """
    candidate = platform.lower().replace(" ", "-")
    if not isinstance(only_platform, str):
        return only_platform.search(candidate) is not None
    if re.search(only_platform, candidate) is not None:
        return True
    # Reverse direction: the escaped platform name as a substring of the filter.
    return re.search(re.escape(candidate), only_platform) is not None