[FR] Add Support for ES|QL Rule Type and Remote Validation (#3281)

* add support for esql type
* add unit tests
* set clients in RemoteConnector from auth methods
* thread remote rules; add engine test
* Add versions to remote validation results

---------

Co-authored-by: Terrance DeJesus <99630311+terrancedejesus@users.noreply.github.com>
Co-authored-by: brokensound77 <brokensound77@users.noreply.github.com>
Co-authored-by: Justin Ibarra <16747370+brokensound77@users.noreply.github.com>

(cherry picked from commit 7514c0a206)
This commit is contained in:
Mika Ayenson
2023-12-08 13:46:28 -06:00
committed by github-actions[bot]
parent 87f8498b68
commit 111ce46b75
10 changed files with 497 additions and 159 deletions
+6 -2
View File
@@ -11,7 +11,7 @@ import uuid
from pathlib import Path
from functools import wraps
from typing import NoReturn
from typing import NoReturn, Optional
import click
import requests
@@ -270,12 +270,16 @@ def load_current_package_version() -> str:
return load_etc_dump('packages.yml')['package']['name']
def get_default_config() -> Optional[Path]:
    """Return the first `.detection-rules-cfg.*` match under `get_path()`, or None when nothing matches."""
    candidates = Path(get_path()).glob('.detection-rules-cfg.*')
    # glob() is lazy; only the first hit (if any) is consumed
    return next(candidates, None)
@cached
def parse_config():
"""Parse a default config file."""
import eql
config_file = next(Path(get_path()).glob('.detection-rules-cfg.*'), None)
config_file = get_default_config()
config = {}
if config_file and config_file.exists():
+203
View File
@@ -0,0 +1,203 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
from dataclasses import dataclass
from datetime import datetime
from functools import cached_property
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional
import elasticsearch
from elasticsearch import Elasticsearch
from marshmallow import ValidationError
from requests import HTTPError
from kibana import Kibana
from .misc import ClientError, getdefault, get_elasticsearch_client, get_kibana_client, load_current_package_version
from .rule import TOMLRule, TOMLRuleContents
from .schemas import definitions
@dataclass
class RemoteValidationResult:
    """Container for the outcome of validating a single rule remotely.

    Fields are positional, so their order is part of the constructor contract.
    """

    # identity of the rule that was validated
    rule_id: definitions.UUIDString
    rule_name: str
    # rule body as a plain dict
    contents: dict
    # version information recorded alongside the results
    rule_version: int
    stack_version: str
    # raw responses; either may be None
    query_results: Optional[dict]
    engine_results: Optional[dict]
class RemoteConnector:
    """Base client class for remote validation and testing.

    Holds an Elasticsearch client (`es_client`) and a Kibana client
    (`kibana_client`). Both start out as None and are populated either from
    config defaults at construction time (`parse_config=True`) or later through
    `auth_es` / `auth_kibana`.
    """

    MAX_RETRIES = 5

    def __init__(self, parse_config: bool = False, **kwargs):
        """Create the connector, optionally building both clients from config defaults.

        :param parse_config: when True, resolve every client argument through
            `getdefault` and try to create both clients immediately.
        :param kwargs: extra keyword arguments forwarded to both client factories.
        """
        # Always define both attributes so a connector created without
        # parse_config can be inspected before auth_es/auth_kibana is called.
        self.es_client: Optional[Elasticsearch] = None
        self.kibana_client: Optional[Kibana] = None

        if not parse_config:
            return

        es_args = ['cloud_id', 'ignore_ssl_errors', 'elasticsearch_url', 'es_user', 'es_password', 'timeout']
        kibana_args = [
            'cloud_id', 'ignore_ssl_errors', 'kibana_url', 'kibana_user', 'kibana_password', 'space', 'kibana_cookie',
            'provider_type', 'provider_name'
        ]
        es_kwargs = {arg: getdefault(arg)() for arg in es_args}
        kibana_kwargs = {arg: getdefault(arg)() for arg in kibana_args}

        # Inject the default only when the caller did not pass max_retries;
        # otherwise the factory call would receive the keyword twice.
        if 'max_retries' not in kwargs:
            es_kwargs['max_retries'] = self.MAX_RETRIES

        try:
            self.es_client = get_elasticsearch_client(**es_kwargs, **kwargs)
        except ClientError:
            # best effort: the client can still be set later via auth_es
            self.es_client = None

        try:
            self.kibana_client = get_kibana_client(**kibana_kwargs, **kwargs)
        except HTTPError:
            # best effort: the client can still be set later via auth_kibana
            self.kibana_client = None

    def auth_es(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None,
                elasticsearch_url: Optional[str] = None, es_user: Optional[str] = None,
                es_password: Optional[str] = None, timeout: Optional[int] = None, **kwargs) -> Elasticsearch:
        """Create an Elasticsearch client, store it on `es_client` and return it.

        `max_retries` defaults to `MAX_RETRIES` unless supplied in `kwargs`.
        """
        kwargs.setdefault('max_retries', self.MAX_RETRIES)
        self.es_client = get_elasticsearch_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors,
                                                  elasticsearch_url=elasticsearch_url, es_user=es_user,
                                                  es_password=es_password, timeout=timeout, **kwargs)
        return self.es_client

    def auth_kibana(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None,
                    kibana_url: Optional[str] = None, kibana_user: Optional[str] = None,
                    kibana_password: Optional[str] = None, space: Optional[str] = None,
                    kibana_cookie: Optional[str] = None, provider_type: Optional[str] = None,
                    provider_name: Optional[str] = None, **kwargs) -> Kibana:
        """Create a Kibana client, store it on `kibana_client` and return it."""
        self.kibana_client = get_kibana_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors,
                                               kibana_url=kibana_url, kibana_user=kibana_user,
                                               kibana_password=kibana_password, space=space,
                                               kibana_cookie=kibana_cookie, provider_type=provider_type,
                                               provider_name=provider_name, **kwargs)
        return self.kibana_client
class RemoteValidator(RemoteConnector):
    """Client class for remote validation.

    `validate_rule` dispatches on the rule type to a `validate_<type>` method
    of this class and additionally runs the detection engine preview.
    """

    def __init__(self, parse_config: bool = False):
        """Create the validator; see `RemoteConnector.__init__` for `parse_config`."""
        super().__init__(parse_config=parse_config)

    @cached_property
    def get_validate_methods(self) -> List[str]:
        """Return the names of all per-rule-type validate methods.

        This is a cached property, so it is read as an attribute rather than called.
        """
        # the two entry points share the prefix but are not per-type validators
        exempt = ('validate_rule', 'validate_rules')
        return [m for m in self.__dir__() if m.startswith('validate_') and m not in exempt]

    def get_validate_method(self, name: str) -> callable:
        """Return the bound validate method called `name`.

        Raises AssertionError when no such validate method exists.
        """
        assert name in self.get_validate_methods, f'validate method {name} not found'
        return getattr(self, name)

    @staticmethod
    def prep_for_preview(contents: TOMLRuleContents) -> dict:
        """Return a copy of the rule's API format extended with the preview parameters."""
        end_time = datetime.utcnow().isoformat()
        dumped = contents.to_api_format().copy()
        dumped.update(timeframeEnd=end_time, invocationCount=1)
        return dumped

    def engine_preview(self, contents: TOMLRuleContents) -> dict:
        """Post the rule to the detection engine preview API and return the response."""
        dumped = self.prep_for_preview(contents)
        return self.kibana_client.post('/api/detection_engine/rules/preview', json=dumped)

    def validate_rule(self, contents: TOMLRuleContents) -> RemoteValidationResult:
        """Validate a single rule: run its type-specific validator, then the engine preview."""
        method = self.get_validate_method(f'validate_{contents.data.type}')
        query_results = method(contents)
        engine_results = self.engine_preview(contents)
        rule_version = contents.autobumped_version
        stack_version = load_current_package_version()
        return RemoteValidationResult(contents.data.rule_id, contents.data.name, contents.to_api_format(),
                                      rule_version, stack_version, query_results, engine_results)

    def validate_rules(self, rules: List[TOMLRule], threads: int = 5) -> Dict[str, RemoteValidationResult]:
        """Validate a collection of rules using a pool of `threads` worker threads.

        Returns a dict keyed by rule id. For a rule whose validation raised a
        ValidationError, the value is that error's `messages` instead of a
        RemoteValidationResult. Any other exception propagates to the caller.
        """
        responses = {}

        def request(c: TOMLRuleContents):
            try:
                responses[c.data.rule_id] = self.validate_rule(c)
            except ValidationError as e:
                responses[c.data.rule_id] = e.messages

        pool = ThreadPool(processes=threads)
        try:
            pool.map(request, [r.contents for r in rules])
        finally:
            # release the worker threads even when a request raised
            pool.close()
            pool.join()

        return responses

    def validate_esql(self, contents: TOMLRuleContents) -> dict:
        """Validate query for "esql" rule types.

        The query is sent with `| LIMIT 0` appended. A bad-request response is
        re-raised as ValidationError; any other failure as a chained Exception.
        """
        query = contents.data.query
        rule_id = contents.data.rule_id
        headers = {"accept": "application/json", "content-type": "application/json"}
        body = {'query': f'{query} | LIMIT 0'}
        try:
            response = self.es_client.perform_request('POST', '/_query', headers=headers, params={'pretty': True},
                                                      body=body)
        except elasticsearch.BadRequestError as exc:
            raise ValidationError(f'ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}') from exc
        except Exception as exc:
            raise Exception(f'ES|QL query failed for rule: {rule_id}, query: \n{query}') from exc

        return response.body

    def validate_eql(self, contents: TOMLRuleContents) -> dict:
        """Validate query for "eql" rule types.

        A bad-request response is re-raised as ValidationError; any other
        failure as a chained Exception.
        """
        query = contents.data.query
        rule_id = contents.data.rule_id
        index = contents.data.index
        time_range = {"range": {"@timestamp": {"gt": 'now-1h/h', "lte": 'now', "format": "strict_date_optional_time"}}}
        body = {'query': query}
        try:
            response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range)
        except elasticsearch.BadRequestError as exc:
            raise ValidationError(f'EQL query failed: {exc} for rule: {rule_id}, query: \n{query}') from exc
        except Exception as exc:
            raise Exception(f'EQL query failed for rule: {rule_id}, query: \n{query}') from exc

        return response.body

    # The stubs below are plain instance methods: validate_rule looks them up on
    # the instance and calls them with `contents` only, which fails for a
    # staticmethod that also declares `self`.

    def validate_query(self, contents: TOMLRuleContents) -> dict:
        """Placeholder for "query" rule types; no remote check is performed."""
        return {'results': 'Unable to remote validate query rules'}

    def validate_threshold(self, contents: TOMLRuleContents) -> dict:
        """Placeholder for "threshold" rule types; no remote check is performed."""
        return {'results': 'Unable to remote validate threshold rules'}

    def validate_new_terms(self, contents: TOMLRuleContents) -> dict:
        """Placeholder for "new_terms" rule types; no remote check is performed."""
        return {'results': 'Unable to remote validate new_terms rules'}

    def validate_threat_match(self, contents: TOMLRuleContents) -> dict:
        """Placeholder for "threat_match" rule types; no remote check is performed."""
        return {'results': 'Unable to remote validate threat_match rules'}

    def validate_machine_learning(self, contents: TOMLRuleContents) -> dict:
        """Placeholder for "machine_learning" rule types; no remote check is performed."""
        return {'results': 'Unable to remote validate machine_learning rules'}
+24 -21
View File
@@ -594,26 +594,11 @@ class QueryRuleData(BaseRuleData):
@validates_schema
def validates_query_data(self, data, **kwargs):
    """Schema-level check for query rules and the rule types derived from them."""
    suppression = data.get('alert_suppression')
    # alert suppression is only valid for query rule type and not any of its subclasses
    if suppression and data['type'] != 'query':
        raise ValidationError("Alert suppression is only valid for query rule type.")
@dataclass(frozen=True)
class ESQLRuleData(QueryRuleData):
"""ESQL rules are a special case of query rules."""
type: Literal["esql"]
language: Literal["esql"]
query: str
@validates_schema
def validate_esql_data(self, data, **kwargs):
"""Custom validation for esql rule type."""
if data.get('index'):
raise ValidationError("Index is not valid for esql rule type.")
@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
type: Literal["machine_learning"]
@@ -726,6 +711,20 @@ class EQLRuleData(QueryRuleData):
return interval / self.max_span
@dataclass(frozen=True)
class ESQLRuleData(QueryRuleData):
    """Rule data for ES|QL rules, a special case of query rules."""
    type: Literal["esql"]
    language: Literal["esql"]
    query: str

    @validates_schema
    def validates_esql_data(self, data, **kwargs):
        """Reject ES|QL rule data that sets a truthy `index` value."""
        index = data.get('index')
        if not index:
            return
        raise ValidationError("Index is not a valid field for ES|QL rule type.")
@dataclass(frozen=True)
class ThreatMatchRuleData(QueryRuleData):
"""Specific fields for indicator (threat) match rule."""
@@ -1096,12 +1095,11 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
packaged_integrations = []
datasets = set()
if data.type != "esql":
for node in data.get('ast', []):
if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset':
datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal)))
elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset':
datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value)))
for node in data.get('ast') or []:
if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset':
datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal)))
elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset':
datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value)))
# integration is None to remove duplicate references upstream in Kibana
# chronologically, event.dataset is checked for package:integration, then rule tags
@@ -1139,6 +1137,10 @@ class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin):
data.data_validator.validate_bbr(metadata.get('bypass_bbr_timing'))
data.validate(metadata) if hasattr(data, 'validate') else False
@staticmethod
def validate_remote(remote_validator: 'RemoteValidator', contents: 'TOMLRuleContents'):
    """Run `contents` through `remote_validator.validate_rule`.

    The validator's return value is discarded and nothing is returned.
    """
    validate = remote_validator.validate_rule
    validate(contents)
def to_dict(self, strip_none_values=True) -> dict:
# Load schemas directly from the data and metadata classes to avoid schema ambiguity which can
# result from union fields which contain classes and related subclasses (AnyRuleData). See issue #1141
@@ -1347,3 +1349,4 @@ def get_unique_query_fields(rule: TOMLRule) -> List[str]:
# avoid a circular import
from .rule_validators import EQLValidator, ESQLValidator, KQLValidator # noqa: E402
from .remote_validation import RemoteValidator # noqa: E402
+12 -5
View File
@@ -5,18 +5,21 @@
"""Validation logic for rules containing queries."""
from functools import cached_property
from typing import List, Optional, Union, Tuple
from semver import Version
from typing import List, Optional, Tuple, Union
import eql
from marshmallow import ValidationError
from semver import Version
import kql
from . import ecs, endgame
from .integrations import get_integration_schema_data, load_integrations_manifests
from .integrations import (get_integration_schema_data,
load_integrations_manifests)
from .misc import load_current_package_version
from .rule import (EQLRuleData, QueryRuleData, QueryValidator, RuleMeta,
TOMLRuleContents, set_eql_config)
from .schemas import get_stack_schemas
from .rule import QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, EQLRuleData, set_eql_config
EQL_ERROR_TYPES = Union[eql.EqlCompileError,
eql.EqlError,
@@ -351,7 +354,6 @@ class ESQLValidator(QueryValidator):
@cached_property
def ast(self):
"""Return an AST."""
return None
@cached_property
@@ -365,6 +367,11 @@ class ESQLValidator(QueryValidator):
"""Validate an ESQL query while checking TOMLRule."""
# temporarily override to NOP until ES|QL query parsing is supported
def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[
        ValidationError, None, ValueError]:
    """No-op placeholder: performs no integration validation and returns None."""
    # disabled delegation, kept for reference:
    # return self.validate(data, meta)
    return None
def extract_error_field(exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]:
"""Extract the field name from an EQL or KQL parse error."""
+1 -1
View File
@@ -138,7 +138,7 @@ CardinalityFields = NewType('CardinalityFields', List[NonEmptyStr], validate=val
CodeString = NewType("CodeString", str)
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
FilterLanguages = Literal["kuery", "lucene", "eql", "esql"]
FilterLanguages = Literal["eql", "esql", "kuery", "lucene"]
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
InvestigateProviderQueryType = Literal["phrase", "range"]
InvestigateProviderValueType = Literal["string", "boolean"]
+4
View File
@@ -0,0 +1,4 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
@@ -0,0 +1,37 @@
[metadata]
creation_date = "2023/11/20"
integration = ["endpoint"]
maturity = "production"
min_stack_comments = "ES|QL Rule"
min_stack_version = "8.11.0"
updated_date = "2023/11/20"
[rule]
author = ["Elastic"]
description = """
Sample ES|QL rule for unit tests.
"""
from = "now-9m"
language = "esql"
license = "Elastic License v2"
name = "Sample ES|QL rule for unit tests"
risk_score = 47
rule_id = "24220495-cffe-45a1-996c-37b599ba0e43"
severity = "medium"
tags = ["Data Source: Elastic Endpoint", "Domain: Endpoint", "OS: Windows", "Use Case: Threat Detection", "Tactic: Command and Control", "Data Source: Elastic Defend"]
timestamp_override = "event.ingested"
type = "esql"
query = '''
from .ds-logs-endpoint.events.process-default-*
| where event.action == "start" and process.code_signature.subject_name like "Microsoft*" and process.parent.name in ("winword.exe", "WINWORD.EXE", "EXCEL.EXE", "excel.exe")
| eval process_path = replace(process.executable, """[cC]:\\[uU][sS][eE][rR][sS]\\[a-zA-Z0-9\.\-\_\$]+\\""", "C:\\\\users\\\\user\\\\")
| stats cc = count(*) by process_path, process.parent.name | where cc <= 5
'''
[[rule.threat]]
framework = "MITRE ATT&CK"
[rule.threat.tactic]
id = "TA0011"
name = "Command and Control"
reference = "https://attack.mitre.org/tactics/TA0011/"
+3 -130
View File
@@ -13,6 +13,7 @@ from collections import defaultdict
from pathlib import Path
import eql.ast
from marshmallow import ValidationError
from semver import Version
@@ -29,8 +30,7 @@ from detection_rules.rule import (QueryRuleData, QueryValidator,
from detection_rules.rule_loader import FILE_PATTERN
from detection_rules.rule_validators import EQLValidator, KQLValidator
from detection_rules.schemas import definitions, get_stack_schemas
from detection_rules.utils import (INTEGRATION_RULE_DIR, PatchedTemplate,
get_path, load_etc_dump)
from detection_rules.utils import INTEGRATION_RULE_DIR, PatchedTemplate, get_path, load_etc_dump
from detection_rules.version_lock import default_version_lock
from rta import get_available_tests
@@ -666,7 +666,7 @@ class TestRuleMetadata(BaseRuleTest):
"f3e22c8b-ea47-45d1-b502-b57b6de950b3"
]
if any([re.search("|".join(non_dataset_packages), i, re.IGNORECASE)
for i in rule.contents.data.index]):
for i in rule.contents.data.get('index') or []]):
if not rule.contents.metadata.integration and rule.id not in ignore_ids and \
rule.contents.data.type not in definitions.MACHINE_LEARNING:
err_msg = f'substrings {non_dataset_packages} found in '\
@@ -1182,35 +1182,6 @@ class TestRiskScoreMismatch(BaseRuleTest):
self.fail(err_msg)
class TestEndpointQuery(BaseRuleTest):
"""Test endpoint-specific rules."""
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
"Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0")
def test_os_and_platform_in_query(self):
"""Test that all endpoint rules have an os defined and linux includes platform."""
for rule in self.production_rules:
if not rule.contents.data.get('language') in ('eql', 'kuery'):
continue
if rule.path.parent.name not in ('windows', 'macos', 'linux'):
# skip cross-platform for now
continue
ast = rule.contents.data.ast
fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))]
err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule'
if 'host.os.type' not in fields:
# Exception for Forwarded Events which contain Windows-only fields.
if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields):
self.assertIn('host.os.type', fields, err_msg)
# going to bypass this for now
# if rule.path.parent.name == 'linux':
# err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule'
# self.assertIn('host.os.platform', fields, err_msg)
class TestNoteMarkdownPlugins(BaseRuleTest):
"""Test if a guide containing Osquery Plugin syntax contains the version note."""
@@ -1334,101 +1305,3 @@ class TestAlertSuppression(BaseRuleTest):
if fld not in schema.keys():
self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \
found in ECS, Beats, or non-ecs schemas")
class TestNewTerms(BaseRuleTest):
"""Test new term rules."""
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
def test_history_window_start(self):
"""Test new terms history window start field."""
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
# validate history window start field exists and is correct
assert rule.contents.data.new_terms.history_window_start, \
"new terms field found with no history_window_start field defined"
assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \
f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
def test_new_terms_field_exists(self):
# validate new terms and history window start fields are correct
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert rule.contents.data.new_terms.field == "new_terms_fields", \
f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
def test_new_terms_fields(self):
"""Test new terms fields are schema validated."""
# ecs validation
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version
assert min_stack_version >= feature_min_stack, \
f"New Terms rule types only compatible with {feature_min_stack}+"
ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs']
beats_version = get_stack_schemas()[str(min_stack_version)]['beats']
# checks if new terms field(s) are in ecs, beats non-ecs or integration schemas
queryvalidator = QueryValidator(rule.contents.data.query)
_, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version)
integration_manifests = load_integrations_manifests()
integration_schemas = load_integrations_schemas()
integration_tags = meta.get("integration")
if integration_tags:
for tag in integration_tags:
latest_tag_compat_ver, _ = find_latest_compatible_version(
package=tag,
integration="",
rule_stack_version=min_stack_version,
packages_manifest=integration_manifests)
if latest_tag_compat_ver:
integration_schema = integration_schemas[tag][latest_tag_compat_ver]
for policy_template in integration_schema.keys():
schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template])
for new_terms_field in rule.contents.data.new_terms.value:
assert new_terms_field in schema.keys(), \
f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
def test_new_terms_max_limit(self):
"""Test new terms max limit."""
# validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
feature_min_stack_extended_fields = Version.parse('8.6.0')
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version
if min_stack_version >= feature_min_stack and \
min_stack_version < feature_min_stack_extended_fields:
assert len(rule.contents.data.new_terms.value) == 1, \
f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
def test_new_terms_fields_unique(self):
"""Test new terms fields are unique."""
# validate fields are unique
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \
f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
+21
View File
@@ -0,0 +1,21 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import unittest
from .base import BaseRuleTest
from detection_rules.misc import get_default_config
# from detection_rules.remote_validation import RemoteValidator
@unittest.skipIf(get_default_config() is None, 'Skipping remote validation due to missing config')
class TestRemoteRules(BaseRuleTest):
"""Test rules against a remote Elastic stack instance."""
# def test_esql_rules(self):
# """Temporarily explicitly test all ES|QL rules remotely pending parsing lib."""
# esql_rules = [r for r in self.all_rules if r.contents.data.type == 'esql']
# rv = RemoteValidator(parse_config=True)
# rv.validate_rules(esql_rules)
+186
View File
@@ -0,0 +1,186 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import unittest
from copy import deepcopy
from pathlib import Path
import eql.ast
from semver import Version
import kql
from detection_rules.integrations import (
find_latest_compatible_version, load_integrations_manifests, load_integrations_schemas
)
from detection_rules.misc import load_current_package_version
from detection_rules.packaging import current_stack_version
from detection_rules.rule import QueryValidator
from detection_rules.rule_loader import RuleCollection
from detection_rules.schemas import get_stack_schemas
from detection_rules.utils import get_path, load_rule_contents
from .base import BaseRuleTest
PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), optional_minor_and_patch=True)
class TestEndpointQuery(BaseRuleTest):
    """Checks that apply to endpoint-specific rules."""

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
                     "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0")
    def test_os_and_platform_in_query(self):
        """Check that eql/kuery rules in an OS-specific folder reference `host.os.type`.

        A missing field is only reported for rules in the `windows` folder that
        use no `winlog.` field.
        """
        checked_languages = ('eql', 'kuery')
        os_folders = ('windows', 'macos', 'linux')

        for rule in self.production_rules:
            if rule.contents.data.get('language') not in checked_languages:
                continue
            if rule.path.parent.name not in os_folders:
                # skip cross-platform for now
                continue

            query_ast = rule.contents.data.ast
            fields = [str(node) for node in query_ast if isinstance(node, (kql.ast.Field, eql.ast.Field))]
            err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule'

            if 'host.os.type' in fields:
                continue

            # Exception for Forwarded Events which contain Windows-only fields.
            if rule.path.parent.name == 'windows' and not any(name.startswith('winlog.') for name in fields):
                self.assertIn('host.os.type', fields, err_msg)

            # going to bypass this for now
            # if rule.path.parent.name == 'linux':
            #     err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule'
            #     self.assertIn('host.os.platform', fields, err_msg)
class TestNewTerms(BaseRuleTest):
    """Test new term rules."""

    def _new_terms_rules(self):
        """Yield every production rule whose type is `new_terms`."""
        for rule in self.production_rules:
            if rule.contents.data.type == "new_terms":
                yield rule

    @staticmethod
    def _effective_min_stack(meta) -> Version:
        """Return the later of the rule's `min_stack_version` and the current package version."""
        current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
        declared = meta.get("min_stack_version")
        if not declared:
            return current_package_version
        declared_version = Version.parse(declared)
        return current_package_version if declared_version < current_package_version else declared_version

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
                     "Test only applicable to 8.4+ stacks for new terms feature.")
    def test_history_window_start(self):
        """Test new terms history window start field."""
        for rule in self._new_terms_rules():
            # validate history window start field exists and is correct
            assert rule.contents.data.new_terms.history_window_start, \
                "new terms field found with no history_window_start field defined"
            assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \
                f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
                     "Test only applicable to 8.4+ stacks for new terms feature.")
    def test_new_terms_field_exists(self):
        """Test that the new terms field name is `new_terms_fields`."""
        for rule in self._new_terms_rules():
            assert rule.contents.data.new_terms.field == "new_terms_fields", \
                f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
                     "Test only applicable to 8.4+ stacks for new terms feature.")
    def test_new_terms_fields(self):
        """Test new terms fields are schema validated."""
        feature_min_stack = Version.parse('8.4.0')
        # these do not depend on the rule, so load them once rather than per rule
        integration_manifests = load_integrations_manifests()
        integration_schemas = load_integrations_schemas()

        # ecs validation
        for rule in self._new_terms_rules():
            meta = rule.contents.metadata
            min_stack_version = self._effective_min_stack(meta)

            assert min_stack_version >= feature_min_stack, \
                f"New Terms rule types only compatible with {feature_min_stack}+"
            stack_schema = get_stack_schemas()[str(min_stack_version)]
            ecs_version = stack_schema['ecs']
            beats_version = stack_schema['beats']

            # checks if new terms field(s) are in ecs, beats non-ecs or integration schemas
            queryvalidator = QueryValidator(rule.contents.data.query)
            _, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version)
            for tag in meta.get("integration") or []:
                latest_tag_compat_ver, _ = find_latest_compatible_version(
                    package=tag,
                    integration="",
                    rule_stack_version=min_stack_version,
                    packages_manifest=integration_manifests)
                if latest_tag_compat_ver:
                    integration_schema = integration_schemas[tag][latest_tag_compat_ver]
                    for policy_template in integration_schema.keys():
                        schema.update(**integration_schema[policy_template])
            for new_terms_field in rule.contents.data.new_terms.value:
                assert new_terms_field in schema.keys(), \
                    f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
                     "Test only applicable to 8.4+ stacks for new terms feature.")
    def test_new_terms_max_limit(self):
        """Test new terms max limit."""
        feature_min_stack = Version.parse('8.4.0')
        feature_min_stack_extended_fields = Version.parse('8.6.0')

        # validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862
        for rule in self._new_terms_rules():
            min_stack_version = self._effective_min_stack(rule.contents.metadata)
            if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields:
                assert len(rule.contents.data.new_terms.value) == 1, \
                    f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"

    @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"),
                     "Test only applicable to 8.6+ stacks for new terms feature.")
    def test_new_terms_fields_unique(self):
        """Test new terms fields are unique."""
        # validate fields are unique
        for rule in self._new_terms_rules():
            values = rule.contents.data.new_terms.value
            assert len(set(values)) == len(values), \
                f"new terms fields values are not unique - {values}"
class TestESQLRules(BaseRuleTest):
    """Test ESQL Rules."""

    def run_esql_test(self, esql_query, expectation, message):
        """Load the dummy production rule with `esql_query` swapped in, inside the `expectation` context."""
        collection = RuleCollection()
        fixture_path = Path(get_path("tests", "data", "command_control_dummy_production_rule.toml"))
        loaded_contents = load_rule_contents(fixture_path)

        # work on a copy so the loaded fixture contents stay untouched
        rule_dict = deepcopy(loaded_contents)[0]
        rule_dict["rule"]["query"] = esql_query

        # Test that a ValidationError is raised if the query doesn't match the schema
        expectation.match_expr = message
        with expectation:
            collection.load_dict(rule_dict)

    def test_esql_queries(self):
        """Test ESQL queries."""
        # NOTE: the cases below are disabled, so this test currently asserts nothing.
        # test_cases = [
        #     # invalid queries
        #     ('from .ds-logs-endpoint.events.process-default-* | wheres process.name like "Microsoft*"',
        #      pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"),
        #     ('from .ds-logs-endpoint.events.process-default-* | where process.names like "Microsoft*"',
        #      pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"),
        #
        #     # valid queries
        #     ('from .ds-logs-endpoint.events.process-default-* | where process.name like "Microsoft*"',
        #      does_not_raise(), None),
        # ]
        # for esql_query, expectation, message in test_cases:
        #     self.run_esql_test(esql_query, expectation, message)