Add unit tests

Co-Authored-By: Justin Ibarra <brokensound77@users.noreply.github.com>
This commit is contained in:
Ross Wolf
2020-06-29 23:19:23 -06:00
parent 3b305d3003
commit d51474f0a7
12 changed files with 1017 additions and 0 deletions
+65
View File
@@ -0,0 +1,65 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Detection Rules tests."""
import glob
import json
import os
from detection_rules.utils import combine_sources
# Absolute path of the directory that contains this module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Root folder of the test data files: the "data" folder next to this module.
DATA_DIR = os.path.join(CURRENT_DIR, 'data')
# Sub-folders of DATA_DIR holding true-positive and false-positive data sets.
TP_DIR = os.path.join(DATA_DIR, 'true_positives')
FP_DIR = os.path.join(DATA_DIR, 'false_positives')
def get_fp_dirs():
    """Return the paths of every entry found directly under FP_DIR."""
    pattern = os.path.join(FP_DIR, '*')
    fp_dirs = glob.glob(pattern)
    return fp_dirs
def get_fp_data_files():
    """Map each false-positive folder name to the combined data of its files."""
    def _load(fp_path):
        # The folder's base name is both the dictionary key and the path
        # component under 'false_positives' that holds its data files.
        name = os.path.basename(fp_path)
        sources = get_data_files(os.path.join('false_positives', name))
        return name, combine_sources(*sources.values())

    return dict(_load(fp_path) for fp_path in get_fp_dirs())
def get_data_files_list(*folder, ext='jsonl', recursive=False):
    """List the data files with extension ``ext`` in a folder under DATA_DIR.

    ``folder`` holds the path components below DATA_DIR. When ``recursive``
    is set, a ``**`` component is added so sub-folders are searched as well.
    """
    sub_path = os.path.sep.join(folder)
    file_glob = '*.{}'.format(ext)
    if recursive:
        pattern = os.path.join(DATA_DIR, sub_path, '**', file_glob)
    else:
        pattern = os.path.join(DATA_DIR, sub_path, file_glob)
    return glob.glob(pattern, recursive=recursive)
def get_data_files(*folder, ext='jsonl', recursive=False):
    """Load every matching data file and return its contents keyed by file name.

    Args:
        *folder: Path components, relative to DATA_DIR, of the folder to read.
        ext: File extension to match. ``'jsonl'`` files are parsed line by
            line into a list of documents; any other extension is parsed as
            a single JSON document.
        recursive: Passed through to ``get_data_files_list``.

    Returns:
        dict mapping the file name (without extension) to the parsed data.
    """
    data_files = {}
    for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive):
        file_name = os.path.splitext(os.path.basename(data_file))[0]
        with open(data_file, 'r') as f:
            if ext == 'jsonl':
                # Skip blank lines (e.g. a trailing newline at the end of the
                # file); json.loads would otherwise raise on them.
                data_files[file_name] = [json.loads(line) for line in f if line.strip()]
            else:
                data_files[file_name] = json.load(f)
    return data_files
def get_data_file(*folder):
    """Load one JSON file located under DATA_DIR.

    ``folder`` holds the path components below DATA_DIR. Returns the parsed
    JSON, or ``None`` when the path does not exist.
    """
    path = os.path.join(DATA_DIR, os.path.sep.join(folder))
    if not os.path.exists(path):
        return None
    with open(path, 'r') as handle:
        return json.load(handle)
+5
View File
@@ -0,0 +1,5 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""KQL unit tests."""
+52
View File
@@ -0,0 +1,52 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import unittest
import kql
class TestEql2Kql(unittest.TestCase):
    """Check that EQL expressions are rendered as the expected KQL text."""

    def validate(self, kql_source, eql_source):
        """Assert that converting ``eql_source`` produces exactly ``kql_source``."""
        rendered = str(kql.from_eql(eql_source))
        self.assertEqual(kql_source, rendered)

    def _validate_each(self, pairs):
        """Run ``validate`` over (expected KQL, EQL input) pairs, in order."""
        for expected_kql, eql_input in pairs:
            self.validate(expected_kql, eql_input)

    def test_field_equals(self):
        """Equality and null checks on a single field."""
        self._validate_each([
            ("field:value", "field == 'value'"),
            ("field:-1", "field == -1"),
            ("field:1.1", "field == 1.1"),
            ("field:true", "field == true"),
            ("field:false", "field == false"),
            ("field:*", "field != null"),
            ("not field:*", "field == null"),
        ])

    def test_field_inequality(self):
        """Range comparisons on a single field."""
        self._validate_each([
            ("field < value", "field < 'value'"),
            ("field > -1", "field > -1"),
            ("field <= 1.1", "field <= 1.1"),
            ("field >= 0", "field >= 0"),
        ])

    def test_or_query(self):
        """Two comparisons joined with ``or``."""
        self.validate("field:value or field2:value2", "field == 'value' or field2 == 'value2'")

    def test_and_query(self):
        """Two comparisons joined with ``and``."""
        self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'")

    def test_not_query(self):
        """Negated comparisons and negated groups."""
        self._validate_each([
            ("not field:value", "field != 'value'"),
            ("not (field:value and field2:value2)", "not (field = 'value' and field2 = 'value2')"),
        ])

    def test_boolean_precedence(self):
        """Parentheses are emitted only where precedence requires them."""
        self._validate_each([
            ("a:1 or b:2 and c:3", "a == 1 or (b == 2 and c == 3)"),
            ("a:1 and (b:2 or c:3)", "a == 1 and (b == 2 or c == 3)"),
            ("a:1 or not b:2 and c:3", "a == 1 or (b != 2 and c == 3)"),
        ])

    def test_list_of_values(self):
        """Several comparisons on one field collapse into a value list."""
        self._validate_each([
            ("a:(0 or 1 or 2 or 3)", "a in (0,1,2,3)"),
            ("a:(0 or 3 or 1 and 2)", "a == 0 or a == 1 and a == 2 or a == 3"),
            ("a:(0 or 1 and 2 or 3 and 4)", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)"),
        ])

    def test_ip_checks(self):
        """IP equality and ``cidrMatch`` calls."""
        self._validate_each([
            ("dest:192.168.255.255", "dest == '192.168.255.255'"),
            ("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')"),
            ("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')"),
        ])
+122
View File
@@ -0,0 +1,122 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import unittest
import kql
class EvaluatorTests(unittest.TestCase):
    """Evaluate KQL queries against an in-memory document."""

    # Default document used by ``evaluate`` when the caller passes none.
    document = {
        "number": 1,
        "boolean": True,
        "ip": "192.168.16.3",
        "string": "hello world",
        "string_list": ["hello world", "example"],
        "number_list": [1, 2, 3],
        "boolean_list": [True, False],
        "structured": [
            {
                "a": [
                    {"b": 1},
                ],
            },
        ],
    }

    def evaluate(self, source_text, document=None):
        """Build an unoptimized evaluator for ``source_text`` and apply it to a document."""
        target = self.document if document is None else document
        return kql.get_evaluator(source_text, optimize=False)(target)

    def _assert_matches(self, *queries):
        """Assert, in order, that every query evaluates truthy."""
        for query in queries:
            self.assertTrue(self.evaluate(query))

    def _assert_no_match(self, *queries):
        """Assert, in order, that every query evaluates falsy."""
        for query in queries:
            self.assertFalse(self.evaluate(query))

    def test_single_value(self):
        self._assert_matches(
            'number:1',
            'number:"1"',
            'boolean:true',
            'string:"hello world"',
        )
        self._assert_no_match(
            'number:0',
            'boolean:false',
            'string:"missing"',
        )

    def test_list_value(self):
        self._assert_matches(
            'number_list:1',
            'number_list:2',
            'number_list:3',
            'boolean_list:true',
            'boolean_list:false',
            'string_list:"hello world"',
            'string_list:example',
        )
        self._assert_no_match(
            'number_list:4',
            'string_list:"missing"',
        )

    def test_and_values(self):
        self._assert_matches(
            'number_list:(1 and 2)',
            'boolean_list:(false and true)',
        )
        self._assert_no_match(
            'string:("missing" and "hello world")',
            'number:(0 and 1)',
            'boolean:(false and true)',
        )

    def test_not_value(self):
        self._assert_matches('number_list:1')
        self._assert_no_match(
            'not number_list:1',
            'number_list:(not 1)',
        )

    def test_or_values(self):
        self._assert_matches(
            'number:(0 or 1)',
            'number:(1 or 2)',
            'boolean:(false or true)',
            'string:("missing" or "hello world")',
        )
        self._assert_no_match('number:(0 or 3)')

    def test_and_expr(self):
        self._assert_matches('number:1 and boolean:true')
        self._assert_no_match('number:1 and boolean:false')

    def test_or_expr(self):
        self._assert_matches('number:1 or boolean:false')
        self._assert_no_match('number:0 or boolean:false')

    def test_range(self):
        self._assert_matches('number < 2')
        self._assert_no_match('number > 2')

    def test_cidr_match(self):
        self._assert_matches('ip:192.168.0.0/16')
        self._assert_no_match('ip:10.0.0.0/8')

    def test_quoted_wildcard(self):
        self._assert_no_match('string:"*"')

    def test_wildcard(self):
        self._assert_matches(
            'string:hello*',
            'string:*world',
        )
        self._assert_no_match('string:foobar*')

    def test_field_exists(self):
        self._assert_matches(
            'number:*',
            'boolean:*',
            'ip:*',
            'string:*',
            'string_list:*',
            'number_list:*',
            'boolean_list:*',
        )
        self._assert_no_match('a:*')

    def test_flattening(self):
        self._assert_matches(
            "structured.a.b:*",
            "structured.a.b:1",
        )
        self._assert_no_match("structured.a.b:2")
+94
View File
@@ -0,0 +1,94 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import unittest
import eql
import kql
class TestKql2Eql(unittest.TestCase):
    """Check that KQL queries convert to the expected EQL expressions."""

    def validate(self, kql_source, eql_source, schema=None):
        """Assert that converting ``kql_source`` equals the parsed ``eql_source``."""
        converted = kql.to_eql(kql_source, schema=schema)
        expected = eql.parse_expression(eql_source)
        self.assertEqual(converted, expected)

    def _validate_each(self, pairs, schema=None):
        """Run ``validate`` over (KQL input, expected EQL) pairs, in order."""
        for kql_input, expected_eql in pairs:
            self.validate(kql_input, expected_eql, schema=schema)

    def test_field_equals(self):
        self._validate_each([
            ("field:value", "field == 'value'"),
            ("field:-1", "field == -1"),
            ("field:1.0", "field == 1.0"),
            ("field:true", "field == true"),
            ("field:false", "field == false"),
            ("not field:*", "field == null"),
            ("field:*", "field != null"),
        ])

    def test_field_inequality(self):
        self._validate_each([
            ("field < value", "field < 'value'"),
            ("field > -1", "field > -1"),
            ("field <= 1.0", "field <= 1.0"),
            ("field >= 0", "field >= 0"),
        ])

    def test_or_query(self):
        self.validate("field:value or field2:value2", "field == 'value' or field2 == 'value2'")

    def test_and_query(self):
        self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'")

    def test_nested_query(self):
        """Nested queries are rejected with a parse error."""
        with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"):
            kql.to_eql("field:{outer:1 and middle:{inner:2}}")

    def test_not_query(self):
        self._validate_each([
            ("not field:value", "field != 'value'"),
            ("not (field:value and field2:value2)", "not (field = 'value' and field2 = 'value2')"),
        ])

    def test_boolean_precedence(self):
        self._validate_each([
            ("a:1 or (b:2 and c:3)", "a == 1 or (b == 2 and c == 3)"),
            ("a:1 or b:2 and c:3", "a == 1 or (b == 2 and c == 3)"),
            ("a:1 or not b:2 and c:3", "a == 1 or (b != 2 and c == 3)"),
        ])

    def test_list_of_values(self):
        self._validate_each([
            ("a:(0 or 1 or 2 or 3)", "a in (0,1,2,3)"),
            ("a:(0 or 1 and 2 or 3)", "a == 0 or a == 1 and a == 2 or a == 3"),
            ("a:(0 or 1 and 2 or (3 and 4))", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)"),
        ])

    def test_lone_value(self):
        """A value without a field is rejected with a parse error."""
        lone_values = ["1", "-1.4", "true", "\"string test\""]
        for value in lone_values:
            with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"):
                kql.to_eql(value)

    def test_schema(self):
        """Conversions and rejections when a field-type schema is supplied."""
        schema = {
            "top": "nested",
            "top.keyword": "keyword",
            "top.text": "text",
            "top.middle": "nested",
            "top.middle.bool": "boolean",
            "top.numL": "long",
            "top.numF": "long",
            "dest": "ip",
        }

        # Queries that convert successfully under the schema.
        self._validate_each([
            ("top.numF : 1", "top.numF == 1"),
            ("top.numF : \"1\"", "top.numF == 1"),
            ("top.keyword : 1", "top.keyword == '1'"),
            ("top.text : \"hello\"", "top.text == 'hello'"),
            ("top.keyword : \"hello\"", "top.keyword == 'hello'"),
            ("top.text : 1 ", "top.text == '1'"),
            ("dest:192.168.255.255", "dest == '192.168.255.255'"),
            ("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')"),
            ("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')"),
        ], schema=schema)

        # Queries that must raise, with the message each one is expected to carry.
        with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match top.middle's type: nested"):
            kql.to_eql("top.middle : 1", schema=schema)

        nested_queries = ("top:{keyword : 1}", "top:{middle:{bool: true}}")
        for nested in nested_queries:
            with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"):
                kql.to_eql(nested, schema=schema)

        invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""]
        for ip in invalid_ips:
            with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"):
                kql.to_eql("dest:{ip}".format(ip=ip), schema=schema)
+79
View File
@@ -0,0 +1,79 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import unittest
import kql
class LintTests(unittest.TestCase):
    """Check the output of ``kql.lint`` for a range of queries."""

    def validate(self, source, linted, *args):
        """Assert that linting ``source`` yields ``linted``; extra args go to assertEqual."""
        actual = kql.lint(source)
        self.assertEqual(actual, linted, *args)

    def _validate_each(self, pairs):
        """Run ``validate`` over (source, expected lint output) pairs, in order."""
        for source, linted in pairs:
            self.validate(source, linted)

    def test_lint_field(self):
        self._validate_each([
            ("a : b", "a:b"),
            ("\"a\": b", "a:b"),
            ("a : \"b\"", "a:b"),
            ("a : (b)", "a:b"),
            ("a:1.234", "a:1.234"),
            ("a:\"1.234\"", "a:1.234"),
        ])

    def test_upper_tokens(self):
        """Upper-case boolean keywords are rejected by the parser."""
        queries = [
            "a:b AND c:d",
            "a:b OR c:d",
            "NOT a:b",
            "a:(b OR c)",
            "a:(b AND c)",
            "a:(NOT b)",
        ]
        for query in queries:
            with self.assertRaises(kql.KqlParseError):
                kql.parse(query)

    def test_lint_precedence(self):
        self._validate_each([
            ("a:b or (c:d and e:f)", "a:b or c:d and e:f"),
            ("(a:b and (c:d or e:f))", "a:b and (c:d or e:f)"),
        ])

    def test_extract_not(self):
        self.validate("a:(not b)", "not a:b")

    def test_merge_fields(self):
        self._validate_each([
            ("a:b or a:c", "a:(b or c)"),
            ("a:b or a:(c or d)", "a:(b or c or d)"),
            ("a:b or a:(c or d) or a:e", "a:(b or c or d or e)"),
        ])
        self.validate("a:b or a:(c and d) or x:y or a:e", "a:(b or e or c and d) or x:y",
                      "Failed to left-align values")
        self.validate("a:b and a:(c and d) or x:y or a:e", "a:(e or b and c and d) or x:y")

    def test_and_not(self):
        self.validate("a:b and not a:c", "a:(b and not c)")

    def test_not_demorgans(self):
        self._validate_each([
            ("not a:b and not a:c and not a:d", "not a:(b or c or d)"),
            ("not a:b or not a:c or not a:d", "not a:(b and c and d)"),
            ("a:(not b and not c and not d)", "not a:(b or c or d)"),
            ("a:(not b or not c or not d)", "not a:(b and c and d)"),
        ])

    def test_not_or(self):
        self.validate("not (a:1 or a:2)", "not a:(1 or 2)")

    def test_mixed_demorgans(self):
        self._validate_each([
            ("a:(b and not c and not d)", "a:(b and not (c or d))"),
            ("a:(b or not c or not d or not e)", "a:(b or not (c and d and e))"),
            ("a:((b or not c or not d) and e)", "a:(e and (b or not (c and d)))"),
        ])

    def test_double_negate(self):
        self._validate_each([
            ("not (not a:b)", "a:b"),
            ("a:(not (not b))", "a:b"),
            ("not (a:(not b))", "a:b"),
            ("not (not (a:b or c:d))", "a:b or c:d"),
            ("not (not (a:(not b) or c:(not d)))", "not a:b or not c:d"),
        ])

    def test_ip(self):
        self.validate("a:ff02\\:\\:fb", "a:\"ff02::fb\"")

    def test_compound(self):
        self.validate("a:1 and b:2 and not (c:3 or c:4)", "a:1 and b:2 and not c:(3 or 4)")
+61
View File
@@ -0,0 +1,61 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import unittest
import kql
from kql.ast import (
Field,
FieldComparison,
String,
Number,
Exists,
)
class ParserTests(unittest.TestCase):
    """Check the syntax trees produced by ``kql.parse``."""

    def validate(self, source, tree, *args, **kwargs):
        """Assert that parsing ``source`` yields ``tree``; ``optimize`` defaults to False."""
        kwargs.setdefault("optimize", False)
        parsed = kql.parse(source, *args, **kwargs)
        self.assertEqual(parsed, tree)

    def test_keyword(self):
        """Values for text and keyword fields parse as strings, quoted or not."""
        schema = {
            "a.text": "text",
            "a.keyword": "keyword",
            "b": "long",
        }

        # (value as written in the query, string value expected in the tree)
        value_cases = [
            ('hello', "hello"),
            ('"hello"', "hello"),
            ('1', "1"),
            ('"1"', "1"),
        ]
        for written, expected in value_cases:
            for field in ("a.text", "a.keyword"):
                self.validate('{}:{}'.format(field, written),
                              FieldComparison(Field(field), String(expected)),
                              schema=schema)

    def test_conversion(self):
        """Values are converted according to the schema type of the field."""
        schema = {"num": "long", "text": "text"}
        cases = [
            ('num:1', FieldComparison(Field("num"), Number(1))),
            ('num:"1"', FieldComparison(Field("num"), Number(1))),
            ('text:1', FieldComparison(Field("text"), String("1"))),
            ('text:"1"', FieldComparison(Field("text"), String("1"))),
        ]
        for source, tree in cases:
            self.validate(source, tree, schema=schema)

    def test_list_equals(self):
        """Value lists compare equal regardless of the order they were written in."""
        forward = kql.parse("a:(1 or 2)", optimize=False)
        backward = kql.parse("a:(2 or 1)", optimize=False)
        self.assertEqual(forward, backward)

    def test_number_exists(self):
        parsed = kql.parse("foo:*", schema={"foo": "long"})
        self.assertEqual(parsed, FieldComparison(Field("foo"), Exists()))

    def test_number_wildcard_fail(self):
        """Wildcard values on a ``long`` field are rejected."""
        for source in ("foo:*wc", "foo:wc*"):
            with self.assertRaises(kql.KqlParseError):
                kql.parse(source, schema={"foo": "long"})
+145
View File
@@ -0,0 +1,145 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Test that all rules have valid metadata and syntax."""
import json
import os
import re
import sys
import unittest
import jsonschema
import kql
import toml
import pytoml
from rta import get_ttp_names
from detection_rules import rule_loader
from detection_rules.utils import load_etc_dump
from detection_rules.rule import Rule
class TestValidRules(unittest.TestCase):
    """Test that all detection rules load properly without duplicates."""

    def test_schema_and_dupes(self):
        """Ensure that at least one rule file is loaded from the rules directory."""
        rule_files = rule_loader.load_rule_files()
        self.assertGreaterEqual(len(rule_files), 1, 'No rules were loaded from rules directory!')

    def test_all_rule_files(self):
        """Ensure that every rule file can be loaded and validate against schema."""
        for file_name, contents in rule_loader.load_rule_files().items():
            try:
                Rule(file_name, contents)
            except (pytoml.TomlError, toml.TomlDecodeError):
                # Name the offending file before re-raising the original error.
                print("TOML error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr)
                raise
            except jsonschema.ValidationError:
                print("Schema error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr)
                raise

    def test_rule_loading(self):
        """Ensure that ``rule_loader.load_rules`` completes without raising."""
        rule_loader.load_rules().values()

    def test_file_names(self):
        """Test that the file names meet the requirement."""
        file_pattern = rule_loader.FILE_PATTERN

        # Sanity-check the pattern itself against two names it must reject.
        self.assertIsNone(re.match(file_pattern, 'NotValidRuleFile.toml'),
                          'Incorrect pattern for verifying rule names: {}'.format(file_pattern))
        self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'),
                          'Incorrect pattern for verifying rule names: {}'.format(file_pattern))

        for rule_file in rule_loader.load_rule_files().keys():
            self.assertIsNotNone(re.match(file_pattern, os.path.basename(rule_file)),
                                 'Invalid file name for {}'.format(rule_file))

    def test_all_rules_as_rule_schema(self):
        """Ensure that every rule file validates against the rule schema."""
        for file_name, contents in rule_loader.load_rule_files().items():
            rule = Rule(file_name, contents)
            rule.validate(as_rule=True)

    def test_all_rules_tuned(self):
        """Ensure that every rule built with ``tune=True`` validates against the rule schema."""
        for file_name, contents in rule_loader.load_rule_files().items():
            rule = Rule(file_name, contents, tune=True)
            rule.validate(as_rule=True)

    def test_all_rule_queries_optimized(self):
        """Ensure that every rule query is in optimized form."""
        for file_name, contents in rule_loader.load_rule_files().items():
            rule = Rule(file_name, contents)

            if rule.query and rule.contents['language'] == 'kuery':
                # A query is already optimized when optimizing its tree changes nothing.
                tree = kql.parse(rule.query, optimize=False)
                optimized = tree.optimize(recursive=True)
                err_message = '\nQuery not optimized for rule: {} - {}\nExpected: {}\nActual: {}'.format(
                    rule.name, rule.id, optimized, rule.query)
                self.assertEqual(tree, optimized, err_message)

    def test_ecs_version_in_query(self):
        """Ensure that all rule queries have ecs version."""
        # NOTE: this check is currently disabled; the original body is kept below.
        # rule_loader.reset()
        # rules = list(rule_loader.load_rules().values())
        #
        # for rule in rules:
        #     ecs_ver = rule.metadata.get('ecs_version')
        #     if ecs_ver:
        #         self.assertTrue('ecs.version:{}'.format(ecs_ver) in rule.query,
        #                         'ecs_version specified but missing from query')

    def test_rules_lint_integrity(self):
        """Verify that linting is not compromising integrity of a rule."""
        # NOTE: this check is currently disabled; the original body is kept below.
        # def validate(source, linted, *args):
        #     self.assertEqual(kql.lint(source), linted, *args)
        #
        # rules = rule_loader.load_rules().values()
        # for rule in rules:
        #     try:
        #         linted = eql2kql.convert(kql2eql.parse(rule.query).render())
        #         validate(rule.query, linted, 'Linting improperly modified the query from: \n\t{} \nto \n\t{}'.format(
        #             rule.query, linted))
        #     except Exception as e:
        #         raise Exception('{} - {}:\n{}'.format(rule.name, rule.query, e))

    def test_no_unrequired_defaults(self):
        """Test that values that are not required in the schema are not set with default values."""
        rules_with_hits = {}

        for file_name, contents in rule_loader.load_rule_files().items():
            rule = Rule(file_name, contents)
            default_matches = rule_loader.find_unneeded_defaults(rule)

            if default_matches:
                rules_with_hits['{} - {}'.format(rule.name, rule.id)] = default_matches

        error_msg = 'The following rules have unnecessary default values set: \n{}'.format(
            json.dumps(rules_with_hits, indent=2))
        self.assertDictEqual(rules_with_hits, {}, error_msg)

    @rule_loader.mock_loader
    def test_production_rules_have_rta(self):
        """Ensure that all production rules have RTAs."""
        mappings = load_etc_dump('rule-mapping.yml')
        ttp_names = get_ttp_names()

        for rule in rule_loader.get_production_rules():
            # Only query rules that appear in the mapping file are checked.
            if rule.type == 'query' and rule.id in mappings:
                matching_rta = mappings[rule.id].get('rta_name')

                self.assertIsNotNone(matching_rta, "Rule {} ({}) does not have RTAs".format(rule.name, rule.id))

                # The mapping stores a file name; TTP names are compared without the extension.
                rta_name, _ = os.path.splitext(matching_rta)
                if rta_name not in ttp_names:
                    self.fail("{} ({}) references unknown RTA: {}".format(rule.name, rule.id, rta_name))
+69
View File
@@ -0,0 +1,69 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Test that all rules appropriately match against expected data sets."""
import copy
import unittest
import warnings
from . import get_data_files, get_fp_data_files
from detection_rules import rule_loader
from detection_rules.utils import combine_sources, evaluate, load_etc_dump
class TestMappings(unittest.TestCase):
    """Test that all rules appropriately match against expected data sets."""

    # Loaded once when the class is defined and shared by the tests below.
    FP_FILES = get_fp_data_files()
    RULES = rule_loader.load_rules().values()

    def evaluate(self, documents, rule, expected, msg):
        """Run ``rule`` against ``documents`` and assert the number of matches.

        Fails with ``msg`` when the number of matching documents differs from
        ``expected``; returns the matching documents otherwise.
        """
        filtered = evaluate(rule, documents)
        self.assertEqual(expected, len(filtered), msg)
        return filtered

    def test_true_positives(self):
        """Test that expected results return against true positives."""
        mismatched_ecs = []
        mappings = load_etc_dump('rule-mapping.yml')

        for rule in rule_loader.get_production_rules():
            if rule.type != 'query' or rule.contents['language'] != 'kuery':
                continue
            if rule.id not in mappings:
                continue

            mapping = mappings[rule.id]
            expected = mapping['count']
            sources = mapping.get('sources')
            rta_file = mapping['rta_name']

            # ensure sources is defined and not empty; schema allows it to not be set since 'pending' bypasses
            self.assertTrue(sources, 'No sources defined for: {} - {} '.format(rule.id, rule.name))
            msg = 'Expected TP results did not match for: {} - {}'.format(rule.id, rule.name)

            # Read the RTA folder from disk once instead of once per source.
            rta_data = get_data_files('true_positives', rta_file)
            data_file = combine_sources(*[rta_data.get(s) for s in sources])
            results = self.evaluate(data_file, rule, expected, msg)

            # Results without an ecs.version contribute None, which can never
            # match a rule version, so leave it out of the comparison.
            ecs_versions = {r.get('ecs', {}).get('version') for r in results} - {None}
            rule_ecs = set(rule.metadata.get('ecs_version') or [])

            if not ecs_versions & rule_ecs:
                # str() keeps the join from raising on non-string version values.
                msg = '{} - {} ecs_versions ({}) not in source data versions ({})'.format(
                    rule.id, rule.name,
                    ', '.join(sorted(str(v) for v in rule_ecs)),
                    ', '.join(sorted(str(v) for v in ecs_versions)))
                mismatched_ecs.append(msg)

        if mismatched_ecs:
            msg = 'Rules detected with source data from ecs versions not listed within the rule: \n{}'.format(
                '\n'.join(mismatched_ecs))
            warnings.warn(msg)

    def test_false_positives(self):
        """Test that expected results return against false positives."""
        for rule in rule_loader.get_production_rules():
            if rule.type != 'query' or rule.contents['language'] != 'kuery':
                continue

            # Reuse the data loaded in FP_FILES rather than re-reading every
            # file for each rule; the deep copy keeps the shared data intact.
            for fp_name, merged_data in self.FP_FILES.items():
                msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name)
                self.evaluate(copy.deepcopy(merged_data), rule, 0, msg)
+149
View File
@@ -0,0 +1,149 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Test that the packages are built correctly."""
import unittest
import uuid
import yaml
from detection_rules import rule_loader
from detection_rules.packaging import Package, PACKAGE_FILE
class TestPackages(unittest.TestCase):
    """Test package building and saving."""

    @staticmethod
    def get_test_rule(version=1, count=1):
        """Build ``count`` test rules together with matching version info.

        Every rule gets a freshly generated ``rule_id``. Returns a
        ``(rules, version_info)`` tuple where ``version_info`` maps each rule
        id to its name, hash and the given ``version``.
        """
        def get_rule_contents():
            return {
                "author": ["Elastic"],
                "description": "test description",
                "language": "kuery",
                "license": "Elastic License",
                "name": "test rule",
                "query": "process.name:test.query",
                "risk_score": 21,
                "rule_id": str(uuid.uuid4()),
                "severity": "low",
                "type": "query",
            }

        rules = [rule_loader.Rule('test.toml', get_rule_contents()) for _ in range(count)]
        version_info = {
            rule.id: {
                'rule_name': rule.name,
                'sha256': rule.get_hash(),
                'version': version,
            }
            for rule in rules
        }

        return rules, version_info

    def test_package_loader_production_config(self):
        """Test that packages are loading correctly."""

    def test_package_loader_default_configs(self):
        """Test configs in etc/packages.yml."""
        with open(PACKAGE_FILE) as f:
            configs = yaml.safe_load(f)['package']

        package = Package.from_config(configs)
        for rule in package.rules:
            # Drop the version before validating the contents as a plain rule.
            rule.contents.pop('version')
            rule.validate(as_rule=True)

    @rule_loader.mock_loader
    def test_package_summary(self):
        """Test the generation of the package summary."""
        rules = list(rule_loader.load_rules().values())
        package = Package(rules, 'test-package')
        changed_rules, new_rules = package.bump_versions(save_changes=False)
        package.generate_summary(changed_rules, new_rules)

    def test_versioning_diffs(self):
        """Test that versioning is detecting diffs as expected."""
        rules, version_info = self.get_test_rule()
        package = Package(rules, 'test', current_versions=version_info)

        # test versioning doesn't falsely detect changes
        changed_rules, new_rules = package.changed_rules, package.new_rules

        self.assertEqual(0, len(changed_rules), 'Package version bumping is improperly detecting changed rules')
        self.assertEqual(0, len(new_rules), 'Package version bumping is improperly detecting new rules')
        self.assertEqual(1, package.rules[0].contents['version'], 'Package version bumping unexpectedly')

        # test versioning detects a new rule
        package.rules[0].contents.pop('version')
        changed_rules, new_rules = package.bump_versions(current_versions={})

        self.assertEqual(0, len(changed_rules), 'Package version bumping is improperly detecting changed rules')
        self.assertEqual(1, len(new_rules), 'Package version bumping is not detecting new rules')
        self.assertEqual(1, package.rules[0].contents['version'],
                         'Package version bumping not setting version to 1 for new rules')

        # test versioning detects a hash changes
        package.rules[0].contents.pop('version')
        package.rules[0].contents['query'] = 'process.name:changed.test.query'
        changed_rules, new_rules = package.bump_versions(current_versions=version_info)

        self.assertEqual(1, len(changed_rules), 'Package version bumping is not detecting changed rules')
        self.assertEqual(0, len(new_rules), 'Package version bumping is improperly detecting new rules')
        self.assertEqual(2, package.rules[0].contents['version'], 'Package version not bumping on changes')

    @rule_loader.mock_loader
    def test_rule_versioning(self):
        """Test that all rules are properly versioned and tracked"""
        self.maxDiff = None
        rules = rule_loader.load_rules().values()
        original_hashes = []
        post_bump_hashes = []

        # test that no rules have versions defined
        for rule in rules:
            self.assertIsNone(rule.contents.get('version'),
                              '{} - {}: explicitly sets a version in the rule file'.format(rule.name, rule.id))
            original_hashes.append(rule.get_hash())

        package = Package(rules, 'test-package')

        # test that all rules have versions defined
        # package.bump_versions(save_changes=False)
        for rule in package.rules:
            # A missing version counts as 0 so the assertion fails with its
            # message instead of raising TypeError on a None comparison.
            self.assertGreaterEqual(rule.contents.get('version') or 0, 1,
                                    '{} - {}: version is not being set in package'.format(rule.name, rule.id))

        # test that rules validate with version
        for rule in package.rules:
            rule.validate(versioned=True)
            rule.contents.pop('version')
            post_bump_hashes.append(rule.get_hash())

        # test that no hashes changed as a result of the version bumps
        self.assertListEqual(original_hashes, post_bump_hashes, 'Version bumping modified the hash of a rule')

    def test_version_filter(self):
        """Test that version filtering is working as expected."""
        msg = 'Package version filter failing'

        # All rules below the minimum version: none are kept.
        rules, version_info = self.get_test_rule(version=1, count=3)
        package = Package(rules, 'test', current_versions=version_info, min_version=2)
        self.assertEqual(0, len(package.rules), msg)

        # All rules above the maximum version: none are kept.
        rules, version_info = self.get_test_rule(version=5, count=3)
        package = Package(rules, 'test', current_versions=version_info, max_version=2)
        self.assertEqual(0, len(package.rules), msg)

        # All rules inside the range: all are kept.
        rules, version_info = self.get_test_rule(version=2, count=3)
        package = Package(rules, 'test', current_versions=version_info, min_version=1, max_version=3)
        self.assertEqual(3, len(package.rules), msg)

        # Versions 1, 2 and 3 with a range of exactly 2: one rule is kept.
        rules, version_info = self.get_test_rule(version=1, count=3)
        for version, vinfo in enumerate(version_info.values(), start=1):
            vinfo['version'] = version

        package = Package(rules, 'test', current_versions=version_info, min_version=2, max_version=2)
        self.assertEqual(1, len(package.rules), msg)
+74
View File
@@ -0,0 +1,74 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import copy
import json
import os
import pytoml
import unittest
from detection_rules.utils import get_etc_path
from detection_rules import rule_loader
from detection_rules.rule_formatter import nested_normalize, toml_write
# Relative path of the scratch file the formatter tests write to, read back and delete.
tmp_file = 'tmp_file.toml'
class TestRuleTomlFormatter(unittest.TestCase):
    """Test that the custom toml formatting is not compromising the integrity of the data."""

    # Fixture data, loaded once when the class is defined.
    with open(get_etc_path('test_toml.json'), 'r') as f:
        test_data = json.load(f)

    def compare_formatted(self, data, callback=None):
        """Compare formatted vs expected.

        Writes ``data`` with ``toml_write``, reads the file back and asserts
        that the round-tripped contents equal the input. ``callback``, when
        given, is applied to the contents read back before the comparison.
        ``data`` itself is left unmodified.
        """
        # Work on a copy so the caller's (possibly shared) fixture is not altered.
        expected = copy.deepcopy(data)

        try:
            toml_write(copy.deepcopy(data), tmp_file)

            with open(tmp_file, 'r') as f:
                formatted_contents = pytoml.load(f)

            # callbacks such as nested normalize leave in line breaks, so this must be manually done
            query = expected.get('rule', {}).get('query')
            if query:
                expected['rule']['query'] = query.strip()

            original = json.dumps(expected, sort_keys=True)

            if callback:
                formatted_contents = callback(formatted_contents)

            # callbacks such as nested normalize leave in line breaks, so this must be manually done
            query = formatted_contents.get('rule', {}).get('query')
            if query:
                formatted_contents['rule']['query'] = query.strip()

            formatted = json.dumps(formatted_contents, sort_keys=True)
            self.assertEqual(original, formatted, 'Formatting may be modifying contents')
        finally:
            # The file may not exist if writing failed; removing it blindly
            # would raise and hide the original error.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    def compare_test_data(self, test_dicts, callback=None):
        """Compare test data against expected."""
        for data in test_dicts:
            self.compare_formatted(data, callback=callback)

    def test_normalization(self):
        """Test that normalization does not change the rule contents."""
        self.compare_test_data([nested_normalize(self.test_data[0])], callback=nested_normalize)

    def test_formatter_rule(self):
        """Test that formatter and encoder do not change the rule contents."""
        self.compare_test_data([self.test_data[0]])

    def test_formatter_deep(self):
        """Test that the data remains unchanged from formatting."""
        self.compare_test_data(self.test_data[1:])

    def test_format_of_all_rules(self):
        """Test all rules."""
        rules = rule_loader.load_rules().values()

        for rule in rules:
            self.compare_formatted(rule.rule_format(formatted_query=False), callback=nested_normalize)
+102
View File
@@ -0,0 +1,102 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
"""Test util time functions."""
import random
import time
import unittest
from detection_rules.utils import normalize_timing_and_sort, cached
from detection_rules.eswrap import Events
from detection_rules.ecs import get_kql_schema
class TestTimeUtils(unittest.TestCase):
    """Tests for timestamp normalization and sorting, schema multifields and the ``cached`` decorator."""

    @staticmethod
    def get_events(timestamp_field='@timestamp'):
        """Build shuffled sample events for each date format.

        Returns a dict keyed by date format name; each value is a shuffled list of six events with ids 1-6
        whose ``timestamp_field`` values grow with the id.
        """
        def _epoch_millis(offset):
            return int(round(time.time(), 3) + offset) * 1000

        def _epoch_second(offset):
            return round(time.time()) + offset

        def _unix_micros(offset):
            return time.time() + offset

        def _unix_millis(offset):
            return round(time.time(), 3) + offset

        def _strict_date_optional_time(offset):
            return '2020-05-13T04:36:' + str(15 + offset) + '.394Z'

        timestamp_makers = {
            'epoch_millis': _epoch_millis,
            'epoch_second': _epoch_second,
            'unix_micros': _unix_micros,
            'unix_millis': _unix_millis,
            'strict_date_optional_time': _strict_date_optional_time,
        }

        events_by_format = {}
        for fmt, make_timestamp in timestamp_makers.items():
            # ids are 1-based and follow the timestamp offset, so sorting by time must yield 1..6
            events = [
                {timestamp_field: make_timestamp(offset), 'foo': 'bar', 'id': offset + 1}
                for offset in range(6)
            ]
            random.shuffle(events)
            events_by_format[fmt] = events

        return events_by_format

    def assert_sort(self, normalized_events, date_format):
        """Assert that the events come back ordered by id 1 through 6."""
        observed_ids = [event['id'] for event in normalized_events]
        failure_message = 'Sorting failed for date_format: {}'.format(date_format)
        self.assertListEqual([1, 2, 3, 4, 5, 6], observed_ids, failure_message)

    def test_time_normalize(self):
        """Test that normalize_timing_and_sort orders the events of every date format."""
        for date_format, events in self.get_events().items():
            self.assert_sort(normalize_timing_and_sort(events), date_format)

    def test_event_class_normalization(self):
        """Test that events are normalized properly within Events."""
        for date_format, events in self.get_events().items():
            wrapped = Events('_', {'winlogbeat': events})
            self.assert_sort(wrapped.events['winlogbeat'], date_format)

    def test_schema_multifields(self):
        """Tests that schemas are loading multifields correctly."""
        schema = get_kql_schema(version="1.4.0")

        expected_types = (("process.name", "keyword"), ("process.name.text", "text"))
        for field, field_type in expected_types:
            self.assertEqual(schema.get(field), field_type)

    def test_caching(self):
        """Test that caching is working."""
        calls = 0

        @cached
        def increment(*args, **kwargs):
            nonlocal calls

            calls += 1
            return calls

        def expect(value, *args):
            # the wrapped function only runs on a cache miss, so a repeated value means a cache hit
            self.assertEqual(increment(*args), value)

        for _ in range(3):
            expect(1)

        for _ in range(3):
            expect(2, ["hello", "world"])

        expect(1)
        expect(2, ["hello", "world"])

        expect(3, {"hello": [("world", )]})
        expect(3, {"hello": [("world", )]})

        expect(1)
        expect(2, ["hello", "world"])
        expect(3, {"hello": [("world", )]})

        # after clearing, every distinct argument set is computed again
        increment.clear()
        expect(4, {"hello": [("world", )]})
        expect(5, ["hello", "world"])
        expect(6)
        expect(7, None)
        expect(8, 1)