Files

416 lines
15 KiB
Python
Raw Permalink Normal View History

2020-06-29 23:17:38 -06:00
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2021-03-03 22:12:11 -09:00
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
2020-06-29 23:17:38 -06:00
"""Misc support."""
2020-06-29 23:17:38 -06:00
import os
2024-08-06 18:07:12 -04:00
import unittest
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import IO, Any, NoReturn
2020-06-29 23:17:38 -06:00
import click
import requests
from elasticsearch import AuthenticationException, Elasticsearch
from kibana import Kibana # type: ignore[reportMissingTypeStubs]
2020-06-29 23:17:38 -06:00
from .utils import add_params, cached, load_etc_dump
2020-06-29 23:17:38 -06:00
LICENSE_HEADER = """
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2021-03-03 22:12:11 -09:00
or more contributor license agreements. Licensed under the Elastic License
2.0; you may not use this file except in compliance with the Elastic License
2.0.
""".strip()
LICENSE_LINES = LICENSE_HEADER.splitlines()
PYTHON_LICENSE = "\n".join("# " + line for line in LICENSE_LINES)
JS_LICENSE = """
/*
{}
*/
""".strip().format("\n".join(" * " + line for line in LICENSE_LINES))
2020-06-29 23:17:38 -06:00
2024-08-06 18:07:12 -04:00
ROOT_DIR = Path(__file__).parent.parent
class ClientError(click.ClickException):
"""Custom CLI error to format output or full debug stacktrace."""
def __init__(self, message: str, original_error: Exception | None = None) -> None:
super().__init__(message)
self.original_error = original_error
self.original_error_type = type(original_error).__name__ if original_error else ""
def show(self, file: IO[Any] | None = None, err: bool = True) -> None:
"""Print the error to the console."""
header = f"Error ({self.original_error_type})"
msg = f"{click.style(header, fg='red', bold=True)}: {self.format_message()}"
click.echo(msg, err=err, file=file)
def raise_client_error( # noqa: PLR0913
message: str,
exc: Exception | None = None,
debug: bool | None = False,
ctx: click.Context | None = None,
file: IO[Any] | None = None,
err: bool = False,
) -> NoReturn:
config_debug = bool(ctx and ctx.ensure_object(dict) and ctx.obj.get("debug")) # type: ignore[reportUnknownArgumentType]
debug = debug if debug is not None else config_debug
if debug:
click.echo(click.style("DEBUG: ", fg="yellow") + message, err=err, file=file)
raise ClientError(message, original_error=exc)
raise ClientError(message, original_error=exc)
def nested_get(_dict: dict[str, Any] | None, dot_key: str | None, default: Any | None = None) -> Any:
2020-06-29 23:17:38 -06:00
"""Get a nested field from a nested dict with dot notation."""
if _dict is None or dot_key is None:
return default
if "." in dot_key:
dot_key_parts = dot_key.split(".")
this_key = dot_key_parts.pop(0)
return nested_get(_dict.get(this_key, default), ".".join(dot_key_parts), default)
return _dict.get(dot_key, default)
2020-06-29 23:17:38 -06:00
def nested_set(_dict: dict[str, Any], dot_key: str, value: Any) -> None:
2022-09-06 15:53:47 -06:00
"""Set a nested field from a key in dot notation."""
keys = dot_key.split(".")
2020-09-23 17:36:34 -05:00
for key in keys[:-1]:
2020-06-29 23:17:38 -06:00
_dict = _dict.setdefault(key, {})
_dict[keys[-1]] = value
2020-06-29 23:17:38 -06:00
def nest_from_dot(dots: str, value: Any) -> Any:
2022-09-06 15:53:47 -06:00
"""Nest a dotted field and set the innermost value."""
fields = dots.split(".")
2022-09-06 15:53:47 -06:00
if not fields:
return {}
nested = {fields.pop(): value}
2024-08-06 18:07:12 -04:00
for field_ in reversed(fields):
nested = {field_: nested}
2022-09-06 15:53:47 -06:00
return nested
def get_kibana_rules_map(repo: str = "elastic/kibana", branch: str = "master") -> dict[str, Any]:
2020-06-29 23:17:38 -06:00
"""Get list of available rules from the Kibana repo and return a list of URLs."""
timeout = 30 # secs
2020-09-02 09:19:17 -08:00
# ensure branch exists
r = requests.get(f"https://api.github.com/repos/{repo}/branches/{branch}", timeout=timeout)
2020-09-02 09:19:17 -08:00
r.raise_for_status()
2020-06-29 23:17:38 -06:00
url = (
"https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/"
"detection_engine/rules/prepackaged_rules?ref={branch}"
)
r = requests.get(url.format(legacy="", app="security_solution", branch=branch, repo=repo), timeout=timeout)
r.raise_for_status()
2020-06-29 23:17:38 -06:00
gh_rules = r.json()
2020-09-02 09:19:17 -08:00
# pre-7.9 app was siem
if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType]
gh_rules = requests.get(url.format(legacy="", app="siem", branch=branch, repo=repo), timeout=timeout).json()
2020-06-29 23:17:38 -06:00
# pre-7.8 the siem was under the legacy directory
if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType]
gh_rules = requests.get(
url.format(legacy="legacy/", app="siem", branch=branch, repo=repo), timeout=timeout
).json()
if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType]
raise ValueError(f"rules directory does not exist for {repo} branch: {branch}")
if not isinstance(gh_rules, list):
raise TypeError("Expected to receive a list")
2020-09-02 09:19:17 -08:00
results: dict[str, Any] = {}
2020-06-29 23:17:38 -06:00
for r in gh_rules: # type: ignore[reportUnknownMemberType]
if "name" not in r:
raise ValueError("Name value is expected")
2020-06-29 23:17:38 -06:00
name = r["name"] # type: ignore[reportUnknownMemberType]
2020-06-29 23:17:38 -06:00
if not isinstance(name, str):
raise TypeError("String value is expected for name")
if name.endswith(".json"):
key = Path(name).name
val = r["download_url"] # type: ignore[reportUnknownMemberType]
results[key] = val
return results
def get_kibana_rules(
repo: str = "elastic/kibana",
branch: str = "master",
verbose: bool = True,
threads: int = 50,
rule_paths: list[str] | None = None,
) -> dict[str, Any]:
2020-06-29 23:17:38 -06:00
"""Retrieve prepackaged rules from kibana repo."""
2020-09-02 09:19:17 -08:00
from multiprocessing.pool import ThreadPool
kibana_rules: dict[str, Any] = {}
2020-09-02 09:19:17 -08:00
2020-06-29 23:17:38 -06:00
if verbose:
thread_use = f" using {threads} threads" if threads > 1 else ""
click.echo(f"Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...")
2020-06-29 23:17:38 -06:00
rule_paths = [os.path.splitext(os.path.basename(p))[0] for p in (rule_paths or [])] # noqa: PTH119, PTH122
rules_mapping = (
[(n, u) for n, u in get_kibana_rules_map(repo=repo, branch=branch).items() if n in rule_paths]
if rule_paths
else get_kibana_rules_map(repo=repo, branch=branch).items()
)
2020-09-02 09:19:17 -08:00
def download_worker(rule_info: tuple[str, str]) -> None:
2020-09-02 09:19:17 -08:00
n, u = rule_info
kibana_rules[n] = requests.get(u, timeout=30).json()
2020-09-02 09:19:17 -08:00
pool = ThreadPool(processes=threads)
_ = pool.map(download_worker, rules_mapping)
2020-09-02 09:19:17 -08:00
pool.close()
pool.join()
return kibana_rules
2020-06-29 23:17:38 -06:00
@cached
def load_current_package_version() -> str:
"""Load the current package version from config file."""
data = load_etc_dump(["packages.yaml"])
return data["package"]["name"]
def get_default_config() -> Path | None:
return next(ROOT_DIR.glob(".detection-rules-cfg.*"), None)
@cached
def parse_user_config() -> dict[str, Any]:
2020-06-29 23:17:38 -06:00
"""Parse a default config file."""
import eql # type: ignore[reportMissingTypeStubs]
config_file = get_default_config()
config = {}
2020-06-29 23:17:38 -06:00
if config_file and config_file.exists():
config = eql.utils.load_dump(str(config_file)) # type: ignore[reportUnknownMemberType]
click.secho(f"Loaded config file: {config_file}", fg="yellow")
2020-06-29 23:17:38 -06:00
return config
2020-06-29 23:17:38 -06:00
def discover_tests(start_dir: str = "tests", pattern: str = "test*.py", top_level_dir: str | None = None) -> list[str]:
2024-08-06 18:07:12 -04:00
"""Discover all unit tests in a directory."""
tests: list[str] = []
def list_tests(s: unittest.TestSuite) -> None:
2024-08-06 18:07:12 -04:00
for test in s:
if isinstance(test, unittest.TestSuite):
list_tests(test)
2024-08-06 18:07:12 -04:00
else:
tests.append(test.id())
loader = unittest.defaultTestLoader
suite = loader.discover(start_dir, pattern=pattern, top_level_dir=top_level_dir or str(ROOT_DIR))
list_tests(suite)
return tests
2024-08-06 18:07:12 -04:00
def getdefault(name: str) -> Callable[[], Any]:
"""Callback function for `default` to get an environment variable."""
envvar = f"DR_{name.upper()}"
2024-08-06 18:07:12 -04:00
config = parse_user_config()
return lambda: os.environ.get(envvar, config.get(name))
def get_elasticsearch_client( # noqa: PLR0913
cloud_id: str | None = None,
elasticsearch_url: str | None = None,
es_user: str | None = None,
es_password: str | None = None,
ctx: click.Context | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> Elasticsearch:
"""Get an authenticated elasticsearch client."""
# Handle empty strings as None
cloud_id = cloud_id or None
elasticsearch_url = elasticsearch_url or None
if not (cloud_id or elasticsearch_url):
raise_client_error("Missing required --cloud-id or --elasticsearch-url")
# don't prompt for these until there's a cloud id or elasticsearch URL
basic_auth: tuple[str, str] | None = None
2024-07-11 14:19:41 -07:00
if not api_key:
es_user = es_user or click.prompt("es_user")
es_password = es_password or click.prompt("es_password", hide_input=True)
if not es_user or not es_password:
raise ValueError("Both username and password must be provided")
2024-07-11 14:19:41 -07:00
basic_auth = (es_user, es_password)
hosts = [elasticsearch_url] if elasticsearch_url else None
timeout = kwargs.pop("timeout", 60)
kwargs["verify_certs"] = not kwargs.pop("ignore_ssl_errors", False)
try:
client = Elasticsearch(
hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key, **kwargs
)
# force login to test auth
_ = client.info()
except AuthenticationException as e:
error_msg = f"Failed authentication for {elasticsearch_url or cloud_id}"
raise_client_error(error_msg, e, ctx=ctx, err=True)
else:
return client
def get_default_elasticsearch_client() -> Elasticsearch:
"""Get an default authenticated elasticsearch client."""
return get_elasticsearch_client(
api_key=getdefault("api_key")(),
cloud_id=getdefault("cloud_id")(),
elasticsearch_url=getdefault("elasticsearch_url")(),
ignore_ssl_errors=getdefault("ignore_ssl_errors")(),
)
def get_kibana_client(
*,
api_key: str,
cloud_id: str | None = None,
kibana_url: str | None = None,
space: str | None = None,
ignore_ssl_errors: bool = False,
**kwargs: Any,
) -> Kibana:
"""Get an authenticated Kibana client."""
if not (cloud_id or kibana_url):
raise_client_error("Missing required --cloud-id or --kibana-url")
2021-09-10 10:06:04 -08:00
verify = not ignore_ssl_errors
return Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, api_key=api_key, **kwargs)
def get_default_kibana_client() -> Kibana:
"""Get a default authenticated Kibana client."""
return get_kibana_client(
api_key=getdefault("api_key")(),
cloud_id=getdefault("cloud_id")(),
kibana_url=getdefault("kibana_url")(),
space=getdefault("space")(),
ignore_ssl_errors=getdefault("ignore_ssl_errors")(),
)
client_options = {
"kibana": {
"kibana_url": click.Option(["--kibana-url"], default=getdefault("kibana_url")),
"cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id"), help="ID of the cloud instance."),
"api_key": click.Option(["--api-key"], default=getdefault("api_key")),
"space": click.Option(["--space"], default=None, help="Kibana space"),
"ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")),
},
"elasticsearch": {
"cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id")),
"api_key": click.Option(["--api-key"], default=getdefault("api_key")),
"elasticsearch_url": click.Option(["--elasticsearch-url"], default=getdefault("elasticsearch_url")),
"es_user": click.Option(["--es-user", "-eu"], default=getdefault("es_user")),
"es_password": click.Option(["--es-password", "-ep"], default=getdefault("es_password")),
"timeout": click.Option(["--timeout", "-et"], default=60, help="Timeout for elasticsearch client"),
"ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")),
},
}
kibana_options = list(client_options["kibana"].values())
elasticsearch_options = list(client_options["elasticsearch"].values())
def add_client(client_types: list[str], add_to_ctx: bool = True, add_func_arg: bool = True) -> Callable[..., Any]:
"""Wrapper to add authed client."""
def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
client_ops_dict: dict[str, click.Option] = {}
client_ops_keys: dict[str, list[str]] = {}
for c_type in client_types:
ops = client_options[c_type]
client_ops_dict.update(ops)
client_ops_keys[c_type] = list(ops)
if not client_ops_dict:
client_types_str = ", ".join(client_types)
raise ValueError(f"Unknown client: {client_types_str} in {func.__name__}")
client_ops = list(client_ops_dict.values())
@wraps(func)
@add_params(*client_ops)
def _wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0912
ctx: click.Context | None = next((a for a in args if isinstance(a, click.Context)), None)
es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get("elasticsearch", [])}
# shared args like cloud_id
kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get("kibana", [])}
if "elasticsearch" in client_types:
# for nested ctx invocation, no need to re-auth if an existing client is already passed
elasticsearch_client: Elasticsearch | None = kwargs.get("elasticsearch_client")
try:
if elasticsearch_client and elasticsearch_client.info():
pass
else:
elasticsearch_client = get_elasticsearch_client(**es_client_args)
except AuthenticationException:
elasticsearch_client = get_elasticsearch_client(**es_client_args)
if add_func_arg:
kwargs["elasticsearch_client"] = elasticsearch_client
if ctx and add_to_ctx:
ctx.obj["es"] = elasticsearch_client
if "kibana" in client_types:
# for nested ctx invocation, no need to re-auth if an existing client is already passed
kibana_client: Kibana | None = kwargs.get("kibana_client")
if kibana_client:
try:
with kibana_client:
if kibana_client.version:
pass # kibana_client is valid and can be used directly
except (requests.HTTPError, AttributeError):
kibana_client = get_kibana_client(**kibana_client_args)
else:
# Instantiate a new Kibana client if none was provided or if the provided one is not usable
kibana_client = get_kibana_client(**kibana_client_args)
if add_func_arg:
kwargs["kibana_client"] = kibana_client
if ctx and add_to_ctx:
ctx.obj["kibana"] = kibana_client
return func(*args, **kwargs)
return _wrapped
return _wrapper