Files

271 lines
8.7 KiB
Python
Raw Permalink Normal View History

import re
from functools import reduce
from typing import Dict, List, Literal, Optional, Union
from uuid import UUID
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
IPvAnyAddress,
StrictFloat,
StringConstraints,
conlist,
constr,
field_serializer,
field_validator,
2024-07-16 13:06:56 -05:00
model_validator,
)
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import Annotated, TypedDict
# Lower-case names of the value types an input argument can declare.
InputArgType = Literal["url", "string", "float", "integer", "path"]

# Platform identifiers accepted in an atomic test's `supported_platforms`.
Platform = Literal[
    "windows",
    "macos",
    "linux",
    "office-365",
    "azure-ad",
    "google-workspace",
    "saas",
    "iaas",
    "containers",
    "iaas:gcp",
    "iaas:azure",
    "iaas:aws",
    "esxi",
]

# Executor names accepted for `executor.name` and `dependency_executor_name`.
ExecutorType = Literal["manual", "powershell", "sh", "bash", "command_prompt"]
# Constrained string for a lower-case, dot-separated host name.
# The pattern requires one or more labels (1-63 characters of a-z, 0-9 and
# "-", starting and ending with a letter or digit), each followed by a dot,
# and then a final label of 2-63 characters with the same start/end rule.
DomainName = Annotated[
    str,
    StringConstraints(
        pattern=r"^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$"
    ),
]
# Technique identifier: "T" followed by four digits, optionally followed by a
# dot and a three-digit sub-technique suffix (e.g. "T1003" or "T1003.001").
# The pattern is anchored with ^...$ so the whole string must match; without
# the anchors a value that merely contains such an ID (for example
# "T1003-extra") would also be accepted.
AttackTechniqueID = Annotated[
    str, StringConstraints(pattern=r"^T\d{4}(?:\.\d{3})?$", min_length=5)
]
def extract_mustached_keys(commands: List[Optional[str]]) -> List[str]:
    """Collect the distinct ``#{name}`` placeholder names used in *commands*.

    Args:
        commands: Command strings to scan. ``None`` and empty strings are
            skipped.

    Returns:
        Every placeholder name exactly once, in order of first appearance.
        The order is deterministic so that callers reporting on these names
        produce the same message on every run.
    """
    # A dict is used as an insertion-ordered set.
    seen = {}
    for command in commands:
        if command:
            # With a single capture group, findall yields the captured names.
            names = re.findall(r"#{(.*?)}", command, re.MULTILINE)
            seen.update(dict.fromkeys(names))
    return list(seen)
# Display name for each platform identifier; built once at import time
# instead of on every call.
_PLATFORM_DISPLAY_NAMES = {
    "macos": "macOS",
    "office-365": "Office 365",
    "windows": "Windows",
    "linux": "Linux",
    "azure-ad": "Azure AD",
    "iaas": "IaaS",
    "saas": "SaaS",
    "iaas:aws": "AWS",
    "iaas:azure": "Azure",
    "iaas:gcp": "GCP",
    "google-workspace": "Google Workspace",
    "containers": "Containers",
    "esxi": "ESXi",
}


def get_supported_platform(platform: Platform) -> str:
    """Return the display name for a platform identifier.

    Args:
        platform: One of the platform identifiers, e.g. ``"macos"``.

    Returns:
        The human-readable name, e.g. ``"macOS"``.

    Raises:
        KeyError: If *platform* is not a known identifier.
    """
    return _PLATFORM_DISPLAY_NAMES[platform]
def get_language(executor: ExecutorType):
    """Map an executor name to a language tag.

    ``"manual"`` maps to an empty string, ``"command_prompt"`` maps to
    ``"cmd"``, and any other executor name is returned unchanged.
    """
    if executor == "manual":
        return ""
    return "cmd" if executor == "command_prompt" else executor
class BaseArgument(TypedDict):
    """Base typed dict for input-argument definitions."""

    # Free-text description of the argument.
    description: str
class UrlArg(BaseArgument):
    """Input argument of type ``url``/``Url``.

    The default may be a domain name, a URL, an IP address, or ``None``.
    """

    default: Optional[DomainName | AnyUrl | IPvAnyAddress]
    type: Literal["url", "Url"]

    @field_serializer("default")
    def serialize_url(self, value):
        """Serialize the default as a plain string, keeping ``None`` as ``None``."""
        # str(None) would emit the literal text "None" for an absent default.
        if value is None:
            return None
        return str(value)
class StringArg(BaseArgument):
    """Input argument of type ``string`` or ``path`` (either capitalization)."""

    # Default value; may be None.
    default: Optional[str]
    type: Literal["string", "path", "String", "Path"]
class IntArg(BaseArgument):
    """Input argument of type ``integer``/``Integer``."""

    # Default value; may be None.
    default: Optional[int]
    type: Literal["integer", "Integer"]
class FloatArg(BaseArgument):
    """Input argument of type ``float``/``Float``."""

    # Default value, declared with pydantic's StrictFloat; may be None.
    default: Optional[StrictFloat]
    type: Literal["float", "Float"]
# Any of the input-argument shapes; the `type` key is the discriminator that
# selects which member of the union applies.
Argument = Annotated[
    Union[FloatArg, IntArg, UrlArg, StringArg], Field(discriminator="type")
]
class StrictModel(BaseModel):
    """Base model configured to validate defaults, forbid extra fields, and validate on assignment."""

    model_config = ConfigDict(
        validate_default=True, extra="forbid", validate_assignment=True
    )
class Executor(StrictModel):
    """Fields common to the executor models."""

    name: ExecutorType
    # Defaults to False when the key is omitted.
    elevation_required: bool = False
class ManualExecutor(Executor):
    """Executor named ``manual``, described by free-text steps."""

    name: Literal["manual"]
    # Required; at least 10 characters long.
    steps: str = Field(..., min_length=10)
class CommandExecutor(Executor):
    """Executor that carries a command and an optional cleanup command."""

    name: Literal["powershell", "sh", "bash", "command_prompt"]
    # Required, non-empty.
    command: constr(min_length=1)
    # Optional; defaults to None.
    cleanup_command: Optional[str] = None
class Dependency(StrictModel):
    """A dependency entry of an atomic test."""

    # Required, non-empty.
    description: constr(min_length=1)
    # Required, non-empty.
    prereq_command: constr(min_length=1)
    # No default is declared, so the key must be present; its value may be None.
    get_prereq_command: Optional[str]
class Atomic(StrictModel):
    """A single atomic test: metadata, executor, dependencies and input arguments.

    Field order matters here: the field validators below read previously
    validated fields through ``info.data``, so ``executor`` and
    ``dependencies`` must stay declared before ``input_arguments`` and
    ``dependency_executor_name``.
    """

    test_number: Optional[str] = None
    name: constr(min_length=1)
    description: constr(min_length=1)
    supported_platforms: conlist(Platform, min_length=1)
    executor: Union[ManualExecutor, CommandExecutor] = Field(..., discriminator="name")
    dependencies: Optional[List[Dependency]] = []
    input_arguments: Dict[constr(min_length=2, pattern=r"^[\w_-]+$"), Argument] = {}
    dependency_executor_name: Optional[ExecutorType] = None
    auto_generated_guid: Optional[UUID] = None

    @classmethod
    def extract_mustached_keys(cls, value: dict) -> List[str]:
        """Return the ``#{name}`` placeholders used by the executor and dependencies.

        Args:
            value: Mapping of already-validated field values; only the
                ``executor`` and ``dependencies`` entries are read.

        Returns:
            The placeholder names found by the module-level
            ``extract_mustached_keys`` helper.
        """
        commands = []
        executor = value.get("executor")
        if isinstance(executor, CommandExecutor):
            commands = [executor.command, executor.cleanup_command]
        if isinstance(executor, ManualExecutor):
            commands = [executor.steps]
        # `dependencies` may be missing or None; treat both as "no dependencies".
        for d in value.get("dependencies") or []:
            commands.extend([d.get_prereq_command, d.prereq_command])
        return extract_mustached_keys(commands)

    @field_validator("dependency_executor_name", mode="before")  # noqa
    @classmethod
    def validate_dep_executor(cls, v, info: ValidationInfo):
        """Reject ``dependency_executor_name`` when the test has no dependencies.

        Raises:
            PydanticCustomError: ``invalid_dependency_executor_name`` when a
                value is given but ``dependencies`` is empty or ``None``.
        """
        # Only judge when `dependencies` itself validated (it is then present
        # in info.data); an explicit None counts as "no dependencies" too.
        if (
            v is not None
            and "dependencies" in info.data
            and not info.data["dependencies"]
        ):
            raise PydanticCustomError(
                "invalid_dependency_executor_name",
                "'dependency_executor_name' is not needed if there are no dependencies. Remove the key from YAML",
                {"loc": ["dependency_executor_name"], "input": None},
            )
        return v

    @model_validator(mode="after")
    def validate_elevation_required(self):
        """Require ``elevation_required`` when a Linux/macOS command uses ``sudo``.

        Raises:
            PydanticCustomError: ``elevation_required_but_not_provided`` when
                the command or cleanup command of a command executor contains
                ``sudo`` while ``elevation_required`` is falsy.
        """
        if (
            ("linux" in self.supported_platforms or "macos" in self.supported_platforms)
            and not self.executor.elevation_required
            and isinstance(self.executor, CommandExecutor)
        ):
            commands = [self.executor.command]
            if self.executor.cleanup_command:
                commands.append(self.executor.cleanup_command)
            if any("sudo" in cmd for cmd in commands):
                raise PydanticCustomError(
                    "elevation_required_but_not_provided",
                    "'elevation_required' shouldn't be empty/false. Since `sudo` is used, set `elevation_required` to true`",
                    {
                        "loc": ["executor", "elevation_required"],
                        "input": self.executor.elevation_required,
                    },
                )
        return self

    @field_validator("input_arguments", mode="before")  # noqa
    @classmethod
    def validate(cls, v, info: ValidationInfo):
        """Check that input arguments and command placeholders match one-to-one.

        Raises:
            PydanticCustomError: ``empty_input_arguments`` when the key is
                present with a ``None`` value; ``unused_input_argument`` when
                an argument is not referenced by any command;
                ``missing_input_argument`` when a placeholder has no
                matching argument.
        """
        if v is None:
            raise PydanticCustomError(
                "empty_input_arguments",
                "'input_arguments' shouldn't be empty. Provide a valid value or remove the key from YAML",
                {"loc": ["input_arguments"], "input": None},
            )
        if not isinstance(v, dict):
            # Not a mapping: let the field's own type validation report it
            # instead of failing here with an AttributeError.
            return v
        remaining = set(cls.extract_mustached_keys(info.data))
        for key in v:
            if key not in remaining:
                raise PydanticCustomError(
                    "unused_input_argument",
                    f"'{key}' is not used in any of the commands",
                    {"loc": ["input_arguments", key], "input": key},
                )
            remaining.discard(key)
        if remaining:
            # Report the alphabetically first missing name so the message is
            # the same on every run.
            missing = min(remaining)
            raise PydanticCustomError(
                "missing_input_argument",
                f"{missing} is not defined in input_arguments",
                {"loc": ["input_arguments"]},
            )
        return v
class Technique(StrictModel):
    """A technique together with its non-empty list of atomic tests."""

    attack_technique: AttackTechniqueID
    display_name: str = Field(..., min_length=5)
    atomic_tests: List[Atomic] = Field(min_length=1)

    @model_validator(mode="before")
    @classmethod
    def validate_dependency_executor_names(cls, data):
        """Check if dependency_executor_name keys are present with empty/None values in atomic tests

        Raises:
            PydanticCustomError: ``empty_dependency_executor_name`` for the
                first atomic test whose ``dependency_executor_name`` key is
                present but ``None`` or an empty string.
        """
        if not isinstance(data, dict):
            return data
        atomic_tests = data.get("atomic_tests")
        if not isinstance(atomic_tests, (list, tuple)):
            # Missing, None or not a sequence: leave it to field validation
            # rather than failing here with a TypeError from enumerate().
            return data
        for i, test in enumerate(atomic_tests):
            if not (isinstance(test, dict) and "dependency_executor_name" in test):
                continue
            value = test["dependency_executor_name"]
            # If the key exists but value is None or empty string, that's an error
            if value is None or value == "":
                raise PydanticCustomError(
                    "empty_dependency_executor_name",
                    "'dependency_executor_name' shouldn't be empty. Provide a valid value ['manual','powershell', 'sh', "
                    "'bash', 'command_prompt'] or remove the key from YAML",
                    {
                        "loc": ["atomic_tests", i, "dependency_executor_name"],
                        "input": value,
                    },
                )
        return data

    def model_post_init(self, __context) -> None:
        """Number the atomic tests as ``<attack_technique>-<1-based position>``."""
        for position, test in enumerate(self.atomic_tests, start=1):
            test.test_number = f"{self.attack_technique}-{position}"