From d51474f0a7ff4bfc09d5e6de73a878fee9b19f03 Mon Sep 17 00:00:00 2001 From: Ross Wolf <31489089+rw-access@users.noreply.github.com> Date: Mon, 29 Jun 2020 23:19:23 -0600 Subject: [PATCH] Add unit tests Co-Authored-By: Justin Ibarra --- tests/__init__.py | 65 +++++++++++++++ tests/kuery/__init__.py | 5 ++ tests/kuery/test_eql2kql.py | 52 ++++++++++++ tests/kuery/test_evaluator.py | 122 ++++++++++++++++++++++++++++ tests/kuery/test_kql2eql.py | 94 +++++++++++++++++++++ tests/kuery/test_lint.py | 79 ++++++++++++++++++ tests/kuery/test_parser.py | 61 ++++++++++++++ tests/test_all_rules.py | 145 +++++++++++++++++++++++++++++++++ tests/test_mappings.py | 69 ++++++++++++++++ tests/test_packages.py | 149 ++++++++++++++++++++++++++++++++++ tests/test_toml_formatter.py | 74 +++++++++++++++++ tests/test_utils.py | 102 +++++++++++++++++++++++ 12 files changed, 1017 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/kuery/__init__.py create mode 100644 tests/kuery/test_eql2kql.py create mode 100644 tests/kuery/test_evaluator.py create mode 100644 tests/kuery/test_kql2eql.py create mode 100644 tests/kuery/test_lint.py create mode 100644 tests/kuery/test_parser.py create mode 100644 tests/test_all_rules.py create mode 100644 tests/test_mappings.py create mode 100644 tests/test_packages.py create mode 100644 tests/test_toml_formatter.py create mode 100644 tests/test_utils.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..4c602303f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,65 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. 
+ +"""Detection Rules tests.""" +import glob +import json +import os + +from detection_rules.utils import combine_sources + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +DATA_DIR = os.path.join(CURRENT_DIR, 'data') +TP_DIR = os.path.join(DATA_DIR, 'true_positives') +FP_DIR = os.path.join(DATA_DIR, 'false_positives') + + +def get_fp_dirs(): + """Get a list of fp dir names.""" + return glob.glob(os.path.join(FP_DIR, '*')) + + +def get_fp_data_files(): + """get FP data files by fp dir name.""" + data = {} + for fp_dir in get_fp_dirs(): + fp_dir_name = os.path.basename(fp_dir) + relative_dir_name = os.path.join('false_positives', fp_dir_name) + data[fp_dir_name] = combine_sources(*get_data_files(relative_dir_name).values()) + + return data + + +def get_data_files_list(*folder, ext='jsonl', recursive=False): + """Get TP or FP file list.""" + folder = os.path.sep.join(folder) + data_dir = [DATA_DIR, folder] + if recursive: + data_dir.append('**') + + data_dir.append('*.{}'.format(ext)) + return glob.glob(os.path.join(*data_dir), recursive=recursive) + + +def get_data_files(*folder, ext='jsonl', recursive=False): + """Get data from data files.""" + data_files = {} + for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive): + with open(data_file, 'r') as f: + file_name = os.path.splitext(os.path.basename(data_file))[0] + + if ext == 'jsonl': + data = f.readlines() + data_files[file_name] = [json.loads(d) for d in data] + else: + data_files[file_name] = json.load(f) + + return data_files + + +def get_data_file(*folder): + file = os.path.join(DATA_DIR, os.path.sep.join(folder)) + if os.path.exists(file): + with open(file, 'r') as f: + return json.load(f) diff --git a/tests/kuery/__init__.py b/tests/kuery/__init__.py new file mode 100644 index 000000000..12d34f0e9 --- /dev/null +++ b/tests/kuery/__init__.py @@ -0,0 +1,5 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. 
Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +"""KQL unit tests.""" diff --git a/tests/kuery/test_eql2kql.py b/tests/kuery/test_eql2kql.py new file mode 100644 index 000000000..c2c5eb560 --- /dev/null +++ b/tests/kuery/test_eql2kql.py @@ -0,0 +1,52 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import unittest +import kql + + +class TestEql2Kql(unittest.TestCase): + + def validate(self, kql_source, eql_source): + self.assertEqual(kql_source, str(kql.from_eql(eql_source))) + + def test_field_equals(self): + self.validate("field:value", "field == 'value'") + self.validate("field:-1", "field == -1") + self.validate("field:1.1", "field == 1.1") + self.validate("field:true", "field == true") + self.validate("field:false", "field == false") + self.validate("field:*", "field != null") + self.validate("not field:*", "field == null") + + def test_field_inequality(self): + self.validate("field < value", "field < 'value'") + self.validate("field > -1", "field > -1") + self.validate("field <= 1.1", "field <= 1.1") + self.validate("field >= 0", "field >= 0") + + def test_or_query(self): + self.validate("field:value or field2:value2", "field == 'value' or field2 == 'value2'") + + def test_and_query(self): + self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'") + + def test_not_query(self): + self.validate("not field:value", "field != 'value'") + self.validate("not (field:value and field2:value2)", "not (field = 'value' and field2 = 'value2')") + + def test_boolean_precedence(self): + self.validate("a:1 or b:2 and c:3", "a == 1 or (b == 2 and c == 3)") + self.validate("a:1 and (b:2 or c:3)", "a == 1 and (b == 2 or c == 3)") + self.validate("a:1 or not b:2 and c:3", "a == 1 or (b != 2 
and c == 3)") + + def test_list_of_values(self): + self.validate("a:(0 or 1 or 2 or 3)", "a in (0,1,2,3)") + self.validate("a:(0 or 3 or 1 and 2)", "a == 0 or a == 1 and a == 2 or a == 3") + self.validate("a:(0 or 1 and 2 or 3 and 4)", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)") + + def test_ip_checks(self): + self.validate("dest:192.168.255.255", "dest == '192.168.255.255'") + self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')") + self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')") diff --git a/tests/kuery/test_evaluator.py b/tests/kuery/test_evaluator.py new file mode 100644 index 000000000..a5bef9aa5 --- /dev/null +++ b/tests/kuery/test_evaluator.py @@ -0,0 +1,122 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import unittest + +import kql + + +class EvaluatorTests(unittest.TestCase): + + document = { + "number": 1, + "boolean": True, + "ip": "192.168.16.3", + "string": "hello world", + + "string_list": ["hello world", "example"], + "number_list": [1, 2, 3], + "boolean_list": [True, False], + "structured": [ + { + "a": [ + {"b": 1} + ] + } + ], + } + + def evaluate(self, source_text, document=None): + if document is None: + document = self.document + + evaluator = kql.get_evaluator(source_text, optimize=False) + return evaluator(document) + + def test_single_value(self): + self.assertTrue(self.evaluate('number:1')) + self.assertTrue(self.evaluate('number:"1"')) + self.assertTrue(self.evaluate('boolean:true')) + self.assertTrue(self.evaluate('string:"hello world"')) + + self.assertFalse(self.evaluate('number:0')) + self.assertFalse(self.evaluate('boolean:false')) + self.assertFalse(self.evaluate('string:"missing"')) + + def test_list_value(self): + self.assertTrue(self.evaluate('number_list:1')) + 
self.assertTrue(self.evaluate('number_list:2')) + self.assertTrue(self.evaluate('number_list:3')) + + self.assertTrue(self.evaluate('boolean_list:true')) + self.assertTrue(self.evaluate('boolean_list:false')) + + self.assertTrue(self.evaluate('string_list:"hello world"')) + self.assertTrue(self.evaluate('string_list:example')) + + self.assertFalse(self.evaluate('number_list:4')) + self.assertFalse(self.evaluate('string_list:"missing"')) + + def test_and_values(self): + self.assertTrue(self.evaluate('number_list:(1 and 2)')) + self.assertTrue(self.evaluate('boolean_list:(false and true)')) + self.assertFalse(self.evaluate('string:("missing" and "hello world")')) + + self.assertFalse(self.evaluate('number:(0 and 1)')) + self.assertFalse(self.evaluate('boolean:(false and true)')) + + def test_not_value(self): + self.assertTrue(self.evaluate('number_list:1')) + self.assertFalse(self.evaluate('not number_list:1')) + self.assertFalse(self.evaluate('number_list:(not 1)')) + + def test_or_values(self): + self.assertTrue(self.evaluate('number:(0 or 1)')) + self.assertTrue(self.evaluate('number:(1 or 2)')) + self.assertTrue(self.evaluate('boolean:(false or true)')) + self.assertTrue(self.evaluate('string:("missing" or "hello world")')) + + self.assertFalse(self.evaluate('number:(0 or 3)')) + + def test_and_expr(self): + self.assertTrue(self.evaluate('number:1 and boolean:true')) + + self.assertFalse(self.evaluate('number:1 and boolean:false')) + + def test_or_expr(self): + self.assertTrue(self.evaluate('number:1 or boolean:false')) + self.assertFalse(self.evaluate('number:0 or boolean:false')) + + def test_range(self): + self.assertTrue(self.evaluate('number < 2')) + self.assertFalse(self.evaluate('number > 2')) + + def test_cidr_match(self): + self.assertTrue(self.evaluate('ip:192.168.0.0/16')) + + self.assertFalse(self.evaluate('ip:10.0.0.0/8')) + + def test_quoted_wildcard(self): + self.assertFalse(self.evaluate('string:"*"')) + + def test_wildcard(self): + 
self.assertTrue(self.evaluate('string:hello*')) + self.assertTrue(self.evaluate('string:*world')) + self.assertFalse(self.evaluate('string:foobar*')) + + def test_field_exists(self): + self.assertTrue(self.evaluate('number:*')) + self.assertTrue(self.evaluate('boolean:*')) + self.assertTrue(self.evaluate('ip:*')) + self.assertTrue(self.evaluate('string:*')) + self.assertTrue(self.evaluate('string_list:*')) + self.assertTrue(self.evaluate('number_list:*')) + self.assertTrue(self.evaluate('boolean_list:*')) + + self.assertFalse(self.evaluate('a:*')) + + def test_flattening(self): + self.assertTrue(self.evaluate("structured.a.b:*")) + self.assertTrue(self.evaluate("structured.a.b:1")) + self.assertFalse(self.evaluate("structured.a.b:2")) diff --git a/tests/kuery/test_kql2eql.py b/tests/kuery/test_kql2eql.py new file mode 100644 index 000000000..94ab81e1c --- /dev/null +++ b/tests/kuery/test_kql2eql.py @@ -0,0 +1,94 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. 
+ +import unittest +import eql + +import kql + + +class TestKql2Eql(unittest.TestCase): + + def validate(self, kql_source, eql_source, schema=None): + self.assertEqual(kql.to_eql(kql_source, schema=schema), eql.parse_expression(eql_source)) + + def test_field_equals(self): + self.validate("field:value", "field == 'value'") + self.validate("field:-1", "field == -1") + self.validate("field:1.0", "field == 1.0") + self.validate("field:true", "field == true") + self.validate("field:false", "field == false") + self.validate("not field:*", "field == null") + self.validate("field:*", "field != null") + + def test_field_inequality(self): + self.validate("field < value", "field < 'value'") + self.validate("field > -1", "field > -1") + self.validate("field <= 1.0", "field <= 1.0") + self.validate("field >= 0", "field >= 0") + + def test_or_query(self): + self.validate("field:value or field2:value2", "field == 'value' or field2 == 'value2'") + + def test_and_query(self): + self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'") + + def test_nested_query(self): + with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + kql.to_eql("field:{outer:1 and middle:{inner:2}}") + + def test_not_query(self): + self.validate("not field:value", "field != 'value'") + self.validate("not (field:value and field2:value2)", "not (field = 'value' and field2 = 'value2')") + + def test_boolean_precedence(self): + self.validate("a:1 or (b:2 and c:3)", "a == 1 or (b == 2 and c == 3)") + self.validate("a:1 or b:2 and c:3", "a == 1 or (b == 2 and c == 3)") + self.validate("a:1 or not b:2 and c:3", "a == 1 or (b != 2 and c == 3)") + + def test_list_of_values(self): + self.validate("a:(0 or 1 or 2 or 3)", "a in (0,1,2,3)") + self.validate("a:(0 or 1 and 2 or 3)", "a == 0 or a == 1 and a == 2 or a == 3") + self.validate("a:(0 or 1 and 2 or (3 and 4))", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)") + + def test_lone_value(self): + 
for value in ["1", "-1.4", "true", "\"string test\""]: + with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"): + kql.to_eql(value) + + def test_schema(self): + schema = { + "top": "nested", + "top.keyword": "keyword", + "top.text": "text", + "top.middle": "nested", + "top.middle.bool": "boolean", + "top.numL": "long", + "top.numF": "long", + "dest": "ip", + } + + self.validate("top.numF : 1", "top.numF == 1", schema=schema) + self.validate("top.numF : \"1\"", "top.numF == 1", schema=schema) + self.validate("top.keyword : 1", "top.keyword == '1'", schema=schema) + self.validate("top.text : \"hello\"", "top.text == 'hello'", schema=schema) + self.validate("top.keyword : \"hello\"", "top.keyword == 'hello'", schema=schema) + self.validate("top.text : 1 ", "top.text == '1'", schema=schema) + self.validate("dest:192.168.255.255", "dest == '192.168.255.255'", schema=schema) + self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) + self.validate("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) + + with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match top.middle's type: nested"): + kql.to_eql("top.middle : 1", schema=schema) + + with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + kql.to_eql("top:{keyword : 1}", schema=schema) + + with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + kql.to_eql("top:{middle:{bool: true}}", schema=schema) + + invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""] + for ip in invalid_ips: + with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"): + kql.to_eql("dest:{ip}".format(ip=ip), schema=schema) diff --git a/tests/kuery/test_lint.py b/tests/kuery/test_lint.py new file mode 100644 index 000000000..a4e43ebd7 --- /dev/null +++ b/tests/kuery/test_lint.py @@ -0,0 +1,79 @@ +# Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +import unittest +import kql + + +class LintTests(unittest.TestCase): + + def validate(self, source, linted, *args): + self.assertEqual(kql.lint(source), linted, *args) + + def test_lint_field(self): + self.validate("a : b", "a:b") + self.validate("\"a\": b", "a:b") + self.validate("a : \"b\"", "a:b") + self.validate("a : (b)", "a:b") + self.validate("a:1.234", "a:1.234") + self.validate("a:\"1.234\"", "a:1.234") + + def test_upper_tokens(self): + queries = [ + "a:b AND c:d", + "a:b OR c:d", + "NOT a:b", + "a:(b OR c)", + "a:(b AND c)", + "a:(NOT b)", + ] + + for q in queries: + with self.assertRaises(kql.KqlParseError): + kql.parse(q) + + def test_lint_precedence(self): + self.validate("a:b or (c:d and e:f)", "a:b or c:d and e:f") + self.validate("(a:b and (c:d or e:f))", "a:b and (c:d or e:f)") + + def test_extract_not(self): + self.validate("a:(not b)", "not a:b") + + def test_merge_fields(self): + self.validate("a:b or a:c", "a:(b or c)") + self.validate("a:b or a:(c or d)", "a:(b or c or d)") + self.validate("a:b or a:(c or d) or a:e", "a:(b or c or d or e)") + + self.validate("a:b or a:(c and d) or x:y or a:e", "a:(b or e or c and d) or x:y", "Failed to left-align values") + self.validate("a:b and a:(c and d) or x:y or a:e", "a:(e or b and c and d) or x:y") + + def test_and_not(self): + self.validate("a:b and not a:c", "a:(b and not c)") + + def test_not_demorgans(self): + self.validate("not a:b and not a:c and not a:d", "not a:(b or c or d)") + self.validate("not a:b or not a:c or not a:d", "not a:(b and c and d)") + self.validate("a:(not b and not c and not d)", "not a:(b or c or d)") + self.validate("a:(not b or not c or not d)", "not a:(b and c and d)") + + def test_not_or(self): + self.validate("not (a:1 or a:2)", "not a:(1 or 2)") + + def 
test_mixed_demorgans(self): + self.validate("a:(b and not c and not d)", "a:(b and not (c or d))") + self.validate("a:(b or not c or not d or not e)", "a:(b or not (c and d and e))") + self.validate("a:((b or not c or not d) and e)", "a:(e and (b or not (c and d)))") + + def test_double_negate(self): + self.validate("not (not a:b)", "a:b") + self.validate("a:(not (not b))", "a:b") + self.validate("not (a:(not b))", "a:b") + self.validate("not (not (a:b or c:d))", "a:b or c:d") + self.validate("not (not (a:(not b) or c:(not d)))", "not a:b or not c:d") + + def test_ip(self): + self.validate("a:ff02\\:\\:fb", "a:\"ff02::fb\"") + + def test_compound(self): + self.validate("a:1 and b:2 and not (c:3 or c:4)", "a:1 and b:2 and not c:(3 or 4)") diff --git a/tests/kuery/test_parser.py b/tests/kuery/test_parser.py new file mode 100644 index 000000000..a7de4f548 --- /dev/null +++ b/tests/kuery/test_parser.py @@ -0,0 +1,61 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. 
+ +import unittest +import kql +from kql.ast import ( + Field, + FieldComparison, + String, + Number, + Exists, +) + + +class ParserTests(unittest.TestCase): + + def validate(self, source, tree, *args, **kwargs): + kwargs.setdefault("optimize", False) + self.assertEqual(kql.parse(source, *args, **kwargs), tree) + + def test_keyword(self): + schema = { + "a.text": "text", + "a.keyword": "keyword", + "b": "long", + } + + self.validate('a.text:hello', FieldComparison(Field("a.text"), String("hello")), schema=schema) + self.validate('a.keyword:hello', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) + + self.validate('a.text:"hello"', FieldComparison(Field("a.text"), String("hello")), schema=schema) + self.validate('a.keyword:"hello"', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) + + self.validate('a.text:1', FieldComparison(Field("a.text"), String("1")), schema=schema) + self.validate('a.keyword:1', FieldComparison(Field("a.keyword"), String("1")), schema=schema) + + self.validate('a.text:"1"', FieldComparison(Field("a.text"), String("1")), schema=schema) + self.validate('a.keyword:"1"', FieldComparison(Field("a.keyword"), String("1")), schema=schema) + + def test_conversion(self): + schema = {"num": "long", "text": "text"} + + self.validate('num:1', FieldComparison(Field("num"), Number(1)), schema=schema) + self.validate('num:"1"', FieldComparison(Field("num"), Number(1)), schema=schema) + + self.validate('text:1', FieldComparison(Field("text"), String("1")), schema=schema) + self.validate('text:"1"', FieldComparison(Field("text"), String("1")), schema=schema) + + def test_list_equals(self): + self.assertEqual(kql.parse("a:(1 or 2)", optimize=False), kql.parse("a:(2 or 1)", optimize=False)) + + def test_number_exists(self): + self.assertEqual(kql.parse("foo:*", schema={"foo": "long"}), FieldComparison(Field("foo"), Exists())) + + def test_number_wildcard_fail(self): + with self.assertRaises(kql.KqlParseError): + 
kql.parse("foo:*wc", schema={"foo": "long"}) + + with self.assertRaises(kql.KqlParseError): + kql.parse("foo:wc*", schema={"foo": "long"}) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py new file mode 100644 index 000000000..cfa65a7be --- /dev/null +++ b/tests/test_all_rules.py @@ -0,0 +1,145 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +"""Test that all rules have valid metadata and syntax.""" +import json +import os +import re +import sys +import unittest + +import jsonschema +import kql +import toml +import pytoml +from rta import get_ttp_names + +from detection_rules import rule_loader +from detection_rules.utils import load_etc_dump +from detection_rules.rule import Rule + + +class TestValidRules(unittest.TestCase): + """Test that all detection rules load properly without duplicates.""" + + def test_schema_and_dupes(self): + """Ensure that every rule matches the schema and there are no duplicates.""" + rule_files = rule_loader.load_rule_files() + self.assertGreaterEqual(len(rule_files), 1, 'No rules were loaded from rules directory!') + + def test_all_rule_files(self): + """Ensure that every rule file can be loaded and validate against schema.""" + rules = [] + + for file_name, contents in rule_loader.load_rule_files().items(): + try: + rule = Rule(file_name, contents) + rules.append(rule) + + except (pytoml.TomlError, toml.TomlDecodeError) as e: + print("TOML error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr) + raise e + + except jsonschema.ValidationError as e: + print("Schema error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr) + raise e + + def test_rule_loading(self): + """Ensure that all rule queries have ecs version.""" + rule_loader.load_rules().values() + + def 
test_file_names(self): + """Test that the file names meet the requirement.""" + file_pattern = rule_loader.FILE_PATTERN + + self.assertIsNone(re.match(file_pattern, 'NotValidRuleFile.toml'), + 'Incorrect pattern for verifying rule names: {}'.format(file_pattern)) + self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'), + 'Incorrect pattern for verifying rule names: {}'.format(file_pattern)) + + for rule_file in rule_loader.load_rule_files().keys(): + self.assertIsNotNone(re.match(file_pattern, os.path.basename(rule_file)), + 'Invalid file name for {}'.format(rule_file)) + + def test_all_rules_as_rule_schema(self): + """Ensure that every rule file validates against the rule schema.""" + for file_name, contents in rule_loader.load_rule_files().items(): + rule = Rule(file_name, contents) + rule.validate(as_rule=True) + + def test_all_rules_tuned(self): + """Ensure that every rule file validates against the rule schema.""" + for file_name, contents in rule_loader.load_rule_files().items(): + rule = Rule(file_name, contents, tune=True) + rule.validate(as_rule=True) + + def test_all_rule_queries_optimized(self): + """Ensure that every rule query is in optimized form.""" + for file_name, contents in rule_loader.load_rule_files().items(): + rule = Rule(file_name, contents) + + if rule.query and rule.contents['language'] == 'kuery': + tree = kql.parse(rule.query, optimize=False) + optimized = tree.optimize(recursive=True) + err_message = '\nQuery not optimized for rule: {} - {}\nExpected: {}\nActual: {}'.format( + rule.name, rule.id, optimized, rule.query) + self.assertEqual(tree, optimized, err_message) + + def test_ecs_version_in_query(self): + """Ensure that all rule queries have ecs version.""" + # rule_loader.reset() + # rules = list(rule_loader.load_rules().values()) + # + # for rule in rules: + # ecs_ver = rule.metadata.get('ecs_version') + # if ecs_ver: + # self.assertTrue('ecs.version:{}'.format(ecs_ver) in rule.query, + # 'ecs_version 
specified but missing from query') + + def test_rules_lint_integrity(self): + """Verify that linting is not compromising integrity of a rule.""" + '''def validate(source, linted, *args): + self.assertEqual(kql.lint(source), linted, *args) + + rules = rule_loader.load_rules().values() + + for rule in rules: + try: + linted = eql2kql.convert(kql2eql.parse(rule.query).render()) + validate(rule.query, linted, 'Linting improperly modified the query from: \n\t{} \nto \n\t{}'.format( + rule.query, linted)) + except Exception as e: + raise Exception('{} - {}:\n{}'.format(rule.name, rule.query, e))''' + + def test_no_unrequired_defaults(self): + """Test that values that are not required in the schema are not set with default values.""" + rules_with_hits = {} + + for file_name, contents in rule_loader.load_rule_files().items(): + rule = Rule(file_name, contents) + default_matches = rule_loader.find_unneeded_defaults(rule) + + if default_matches: + rules_with_hits['{} - {}'.format(rule.name, rule.id)] = default_matches + + error_msg = 'The following rules have unnecessary default values set: \n{}'.format( + json.dumps(rules_with_hits, indent=2)) + self.assertDictEqual(rules_with_hits, {}, error_msg) + + @rule_loader.mock_loader + def test_production_rules_have_rta(self): + """Ensure that all production rules have RTAs.""" + mappings = load_etc_dump('rule-mapping.yml') + + ttp_names = get_ttp_names() + + for rule in rule_loader.get_production_rules(): + if rule.type == 'query' and rule.id in mappings: + matching_rta = mappings[rule.id].get('rta_name') + + self.assertIsNotNone(matching_rta, "Rule {} ({}) does not have RTAs".format(rule.name, rule.id)) + + rta_name, ext = os.path.splitext(matching_rta) + if rta_name not in ttp_names: + self.fail("{} ({}) references unknown RTA: {}".format(rule.name, rule.id, rta_name)) diff --git a/tests/test_mappings.py b/tests/test_mappings.py new file mode 100644 index 000000000..e319b7ff9 --- /dev/null +++ b/tests/test_mappings.py @@ -0,0 
+1,69 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. + +"""Test that all rules appropriately match against expected data sets.""" +import copy +import unittest +import warnings + +from . import get_data_files, get_fp_data_files +from detection_rules import rule_loader +from detection_rules.utils import combine_sources, evaluate, load_etc_dump + + +class TestMappings(unittest.TestCase): + """Test that all rules appropriately match against expected data sets.""" + + FP_FILES = get_fp_data_files() + RULES = rule_loader.load_rules().values() + + def evaluate(self, documents, rule, expected, msg): + """KQL engine to evaluate.""" + filtered = evaluate(rule, documents) + self.assertEqual(expected, len(filtered), msg) + return filtered + + def test_true_positives(self): + """Test that expected results return against true positives.""" + mismatched_ecs = [] + mappings = load_etc_dump('rule-mapping.yml') + + for rule in rule_loader.get_production_rules(): + if rule.type == 'query' and rule.contents['language'] == 'kuery': + if rule.id not in mappings: + continue + + mapping = mappings[rule.id] + expected = mapping['count'] + sources = mapping.get('sources') + rta_file = mapping['rta_name'] + + # ensure sources is defined and not empty; schema allows it to not be set since 'pending' bypasses + self.assertTrue(sources, 'No sources defined for: {} - {} '.format(rule.id, rule.name)) + msg = 'Expected TP results did not match for: {} - {}'.format(rule.id, rule.name) + + data_files = [get_data_files('true_positives', rta_file).get(s) for s in sources] + data_file = combine_sources(*data_files) + results = self.evaluate(data_file, rule, expected, msg) + + ecs_versions = set([r.get('ecs', {}).get('version') for r in results]) + rule_ecs = set(rule.metadata.get('ecs_version').copy()) + + if not 
ecs_versions & rule_ecs: + msg = '{} - {} ecs_versions ({}) not in source data versions ({})'.format( + rule.id, rule.name, ', '.join(rule_ecs), ', '.join(ecs_versions)) + mismatched_ecs.append(msg) + + if mismatched_ecs: + msg = 'Rules detected with source data from ecs versions not listed within the rule: \n{}'.format( + '\n'.join(mismatched_ecs)) + warnings.warn(msg) + + def test_false_positives(self): + """Test that expected results return against false positives.""" + for rule in rule_loader.get_production_rules(): + if rule.type == 'query' and rule.contents['language'] == 'kuery': + for fp_name, merged_data in get_fp_data_files().items(): + msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name) + self.evaluate(copy.deepcopy(merged_data), rule, 0, msg) diff --git a/tests/test_packages.py b/tests/test_packages.py new file mode 100644 index 000000000..b118207fe --- /dev/null +++ b/tests/test_packages.py @@ -0,0 +1,149 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. 
"""Test that the packages are built correctly."""
import unittest
import uuid

import yaml

from detection_rules import rule_loader
from detection_rules.packaging import Package, PACKAGE_FILE


class TestPackages(unittest.TestCase):
    """Test package building and saving."""

    @staticmethod
    def get_test_rule(version=1, count=1):
        """Return `count` minimal query rules plus matching version-lock info.

        Each generated rule gets a fresh rule_id; `version` seeds the
        version-lock entry for every rule.
        """
        def get_rule_contents():
            # Minimal set of fields required for a valid 'query' rule.
            return {
                "author": ["Elastic"],
                "description": "test description",
                "language": "kuery",
                "license": "Elastic License",
                "name": "test rule",
                "query": "process.name:test.query",
                "risk_score": 21,
                "rule_id": str(uuid.uuid4()),
                "severity": "low",
                "type": "query"
            }

        rules = [rule_loader.Rule('test.toml', get_rule_contents()) for _ in range(count)]
        version_info = {
            rule.id: {
                'rule_name': rule.name,
                'sha256': rule.get_hash(),
                'version': version
            } for rule in rules
        }

        return rules, version_info

    def test_package_loader_production_config(self):
        """Test that packages are loading correctly."""
        # NOTE(review): placeholder test - no production-config assertions yet.

    def test_package_loader_default_configs(self):
        """Test configs in etc/packages.yml."""
        with open(PACKAGE_FILE) as f:
            configs = yaml.safe_load(f)['package']

        package = Package.from_config(configs)
        for rule in package.rules:
            # 'version' is injected by packaging; strip it so the contents
            # validate against the plain (unversioned) rule schema.
            rule.contents.pop('version')
            rule.validate(as_rule=True)

    @rule_loader.mock_loader
    def test_package_summary(self):
        """Test the generation of the package summary."""
        rules = list(rule_loader.load_rules().values())
        package = Package(rules, 'test-package')
        changed_rules, new_rules = package.bump_versions(save_changes=False)
        package.generate_summary(changed_rules, new_rules)

    def test_versioning_diffs(self):
        """Test that versioning is detecting diffs as expected."""
        rules, version_info = self.get_test_rule()
        package = Package(rules, 'test', current_versions=version_info)

        # test versioning doesn't falsely detect changes
        changed_rules, new_rules = package.changed_rules, package.new_rules

        self.assertEqual(0, len(changed_rules), 'Package version bumping is improperly detecting changed rules')
        self.assertEqual(0, len(new_rules), 'Package version bumping is improperly detecting new rules')
        self.assertEqual(1, package.rules[0].contents['version'], 'Package version bumping unexpectedly')

        # test versioning detects a new rule
        package.rules[0].contents.pop('version')
        changed_rules, new_rules = package.bump_versions(current_versions={})

        self.assertEqual(0, len(changed_rules), 'Package version bumping is improperly detecting changed rules')
        self.assertEqual(1, len(new_rules), 'Package version bumping is not detecting new rules')
        self.assertEqual(1, package.rules[0].contents['version'],
                         'Package version bumping not setting version to 1 for new rules')

        # test versioning detects a hash changes
        package.rules[0].contents.pop('version')
        package.rules[0].contents['query'] = 'process.name:changed.test.query'
        changed_rules, new_rules = package.bump_versions(current_versions=version_info)

        self.assertEqual(1, len(changed_rules), 'Package version bumping is not detecting changed rules')
        self.assertEqual(0, len(new_rules), 'Package version bumping is improperly detecting new rules')
        self.assertEqual(2, package.rules[0].contents['version'], 'Package version not bumping on changes')

    @rule_loader.mock_loader
    def test_rule_versioning(self):
        """Test that all rules are properly versioned and tracked."""
        self.maxDiff = None
        rules = rule_loader.load_rules().values()
        original_hashes = []
        post_bump_hashes = []

        # test that no rules have versions defined
        for rule in rules:
            # fix: the '{} - {}' placeholders were never filled in, so failure
            # messages did not identify the offending rule
            self.assertIsNone(rule.contents.get('version'),
                              '{} - {}: explicitly sets a version in the rule file'.format(rule.id, rule.name))
            original_hashes.append(rule.get_hash())

        # Package() assigns versions to every rule on construction
        package = Package(rules, 'test-package')

        # test that all rules have versions defined
        for rule in package.rules:
            self.assertGreaterEqual(rule.contents.get('version'), 1,
                                    '{} - {}: version is not being set in package'.format(rule.id, rule.name))

        # test that rules validate with version
        for rule in package.rules:
            rule.validate(versioned=True)
            rule.contents.pop('version')
            post_bump_hashes.append(rule.get_hash())

        # test that no hashes changed as a result of the version bumps
        self.assertListEqual(original_hashes, post_bump_hashes, 'Version bumping modified the hash of a rule')

    def test_version_filter(self):
        """Test that version filtering is working as expected."""
        msg = 'Package version filter failing'

        rules, version_info = self.get_test_rule(version=1, count=3)
        package = Package(rules, 'test', current_versions=version_info, min_version=2)
        self.assertEqual(0, len(package.rules), msg)

        rules, version_info = self.get_test_rule(version=5, count=3)
        package = Package(rules, 'test', current_versions=version_info, max_version=2)
        self.assertEqual(0, len(package.rules), msg)

        rules, version_info = self.get_test_rule(version=2, count=3)
        package = Package(rules, 'test', current_versions=version_info, min_version=1, max_version=3)
        self.assertEqual(3, len(package.rules), msg)

        rules, version_info = self.get_test_rule(version=1, count=3)

        # assign versions 1..3 so exactly one rule falls inside [2, 2]
        for version, vinfo in enumerate(version_info.values(), 1):
            vinfo['version'] = version

        package = Package(rules, 'test', current_versions=version_info, min_version=2, max_version=2)
        self.assertEqual(1, len(package.rules), msg)

# === patch boundary: new file tests/test_toml_formatter.py ===
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import copy
import json
import os
import unittest

import pytoml

from detection_rules import rule_loader
from detection_rules.rule_formatter import nested_normalize, toml_write
from detection_rules.utils import get_etc_path

# scratch file used by the round-trip tests below
tmp_file = 'tmp_file.toml'


class TestRuleTomlFormatter(unittest.TestCase):
    """Test that the custom toml formatting is not compromising the integrity of the data."""

    # loaded once at class-definition time; shared read-only by all tests
    with open(get_etc_path('test_toml.json'), 'r') as f:
        test_data = json.load(f)

    def compare_formatted(self, data, callback=None):
        """Round-trip `data` through toml_write/pytoml and assert it is unchanged."""
        try:
            toml_write(copy.deepcopy(data), tmp_file)

            with open(tmp_file, 'r') as f:
                formatted_contents = pytoml.load(f)

            # callbacks such as nested normalize leave in line breaks, so this must be manually done
            query = data.get('rule', {}).get('query')
            if query:
                data['rule']['query'] = query.strip()

            original = json.dumps(copy.deepcopy(data), sort_keys=True)

            if callback:
                formatted_contents = callback(formatted_contents)

            # callbacks such as nested normalize leave in line breaks, so this must be manually done
            query = formatted_contents.get('rule', {}).get('query')
            if query:
                formatted_contents['rule']['query'] = query.strip()

            formatted = json.dumps(formatted_contents, sort_keys=True)
            self.assertEqual(original, formatted, 'Formatting may be modifying contents')

        finally:
            # fix: toml_write may raise before the file is created; guard the
            # cleanup so FileNotFoundError doesn't mask the original exception
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    def compare_test_data(self, test_dicts, callback=None):
        """Compare test data against expected."""
        for data in test_dicts:
            self.compare_formatted(data, callback=callback)

    def test_normalization(self):
        """Test that normalization does not change the rule contents."""
        self.compare_test_data([nested_normalize(self.test_data[0])], callback=nested_normalize)

    def test_formatter_rule(self):
        """Test that formatter and encoder do not change the rule contents."""
        self.compare_test_data([self.test_data[0]])

    def test_formatter_deep(self):
        """Test that the data remains unchanged from formatting."""
        self.compare_test_data(self.test_data[1:])

    def test_format_of_all_rules(self):
        """Test all rules."""
        rules = rule_loader.load_rules().values()

        for rule in rules:
            self.compare_formatted(rule.rule_format(formatted_query=False), callback=nested_normalize)

# === patch boundary: new file tests/test_utils.py ===
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.

"""Test util time functions."""
import random
import time
import unittest

from detection_rules.ecs import get_kql_schema
from detection_rules.eswrap import Events
from detection_rules.utils import normalize_timing_and_sort, cached


class TestTimeUtils(unittest.TestCase):
    """Test util time functions."""

    @staticmethod
    def get_events(timestamp_field='@timestamp'):
        """Get test data: the same six events, shuffled, in several date formats."""
        date_formats = {
            'epoch_millis': lambda x: int(round(time.time(), 3) + x) * 1000,
            'epoch_second': lambda x: round(time.time()) + x,
            'unix_micros': lambda x: time.time() + x,
            'unix_millis': lambda x: round(time.time(), 3) + x,
            'strict_date_optional_time': lambda x: '2020-05-13T04:36:' + str(15 + x) + '.394Z'
        }

        def _get_data(func):
            # ids 1..6 follow timestamp order; shuffle so sorting is exercised
            data = [{timestamp_field: func(offset), 'foo': 'bar', 'id': offset + 1} for offset in range(6)]
            random.shuffle(data)
            return data

        return {fmt: _get_data(func) for fmt, func in date_formats.items()}

    def assert_sort(self, normalized_events, date_format):
        """Assert normalize and sort."""
        order = [e['id'] for e in normalized_events]
        self.assertListEqual([1, 2, 3, 4, 5, 6], order, 'Sorting failed for date_format: {}'.format(date_format))

    def test_time_normalize(self):
        """Test normalize_timing_from_date_format."""
        events_data = self.get_events()
        for date_format, events in events_data.items():
            normalized = normalize_timing_and_sort(events)
            self.assert_sort(normalized, date_format)

    def test_event_class_normalization(self):
        """Test that events are normalized properly within Events."""
        events_data = self.get_events()
        for date_format, events in events_data.items():
            normalized = Events('_', {'winlogbeat': events})
            self.assert_sort(normalized.events['winlogbeat'], date_format)

    def test_schema_multifields(self):
        """Tests that schemas are loading multifields correctly."""
        schema = get_kql_schema(version="1.4.0")
        self.assertEqual(schema.get("process.name"), "keyword")
        self.assertEqual(schema.get("process.name.text"), "text")

    def test_caching(self):
        """Test that caching is working."""
        counter = 0

        @cached
        def increment(*args, **kwargs):
            # each cache *miss* bumps the counter; repeated identical args
            # must return the memoized value instead
            nonlocal counter
            counter += 1
            return counter

        self.assertEqual(increment(), 1)
        self.assertEqual(increment(), 1)
        self.assertEqual(increment(), 1)

        self.assertEqual(increment(["hello", "world"]), 2)
        self.assertEqual(increment(["hello", "world"]), 2)
        self.assertEqual(increment(["hello", "world"]), 2)

        self.assertEqual(increment(), 1)
        self.assertEqual(increment(["hello", "world"]), 2)

        self.assertEqual(increment({"hello": [("world", )]}), 3)
        self.assertEqual(increment({"hello": [("world", )]}), 3)

        self.assertEqual(increment(), 1)
        self.assertEqual(increment(["hello", "world"]), 2)
        self.assertEqual(increment({"hello": [("world", )]}), 3)

        # clearing the cache forces every subsequent call to miss again
        increment.clear()
        self.assertEqual(increment({"hello": [("world", )]}), 4)
        self.assertEqual(increment(["hello", "world"]), 5)
        self.assertEqual(increment(), 6)
        self.assertEqual(increment(None), 7)
        self.assertEqual(increment(1), 8)