diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index 9dcfdb468..dbaf62361 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -93,68 +93,112 @@ class KQLValidator(QueryValidator): print(err_trailer) return exc - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[ - KQL_ERROR_TYPES, None, TypeError]: + def validate_integration( + self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict] + ) -> Union[KQL_ERROR_TYPES, None, TypeError]: """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": - # syntax only, which is done via self.ast return error_fields = {} - current_stack_version = "" - combined_schema = {} - for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): - ecs_version = integration_schema_data['ecs_version'] - integration = integration_schema_data['integration'] - package = integration_schema_data['package'] - package_version = integration_schema_data['package_version'] - integration_schema = integration_schema_data['schema'] - stack_version = integration_schema_data['stack_version'] + package_schemas = {} - if stack_version != current_stack_version: - # reset the combined schema for each stack version - current_stack_version = stack_version - combined_schema = {} + # Initialize package_schemas with a nested structure + for integration_data in package_integrations: + package = integration_data["package"] + integration = integration_data["integration"] + if integration: + package_schemas.setdefault(package, {}).setdefault(integration, {}) + else: + package_schemas.setdefault(package, {}) - # add non-ecs-schema fields for edge cases not added to the integration + # Process each integration schema + for integration_schema_data in get_integration_schema_data( + data, meta, package_integrations + ): + package, integration = ( + integration_schema_data["package"], + integration_schema_data["integration"], + ) + integration_schema = integration_schema_data["schema"] + + # Add non-ecs-schema fields for index_name in data.index: integration_schema.update(**ecs.flatten(ecs.get_index_schema(index_name))) - # add endpoint schema fields for multi-line fields + # Add endpoint schema fields for multi-line fields integration_schema.update(**ecs.flatten(ecs.get_endpoint_schemas())) - combined_schema.update(**integration_schema) + if integration: + package_schemas[package][integration] = integration_schema + else: + package_schemas[package] = integration_schema + # Validate the query against the schema try: - # validate the query against the integration fields with the package version kql.parse(self.query, schema=integration_schema) except kql.KqlParseError as exc: if exc.error_msg == "Unknown field": - field = extract_error_field(exc) - trailer = (f"\n\tTry adding event.module or event.dataset to specify integration module\n\t" - f"Will check against integrations {meta.integration} combined.\n\t" - f"{package=}, {integration=}, {package_version=}, " - f"{stack_version=}, {ecs_version=}" - ) - error_fields[field] = {"error": exc, "trailer": trailer} + field = extract_error_field(self.query, exc) + trailer = ( + f"\n\tTry adding event.module or event.dataset to specify integration module\n\t" + f"Will check against integrations {meta.integration} combined.\n\t" + f"{package=}, {integration=}, {integration_schema_data['package_version']=}, " + f"{integration_schema_data['stack_version']=}, " + f"{integration_schema_data['ecs_version']=}" + ) + error_fields[field] = { + "error": exc, + "trailer": trailer, + "package": package, + "integration": integration, + } if data.get("notify", False): - print(f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}") + print( + f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" + ) else: - return kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) + return kql.KqlParseError( + exc.error_msg, + exc.line, + exc.column, + exc.source, + len(exc.caret.lstrip()), + exc.trailer, + ) - # don't error on fields that are in another integration schema - for field in list(error_fields.keys()): - if field in combined_schema: - del error_fields[field] + # Check error fields against schemas of different packages or different integrations + for field, error_data in list(error_fields.items()): + error_package, error_integration = ( + error_data["package"], + error_data["integration"], + ) + for package, integrations_or_schema in package_schemas.items(): + if error_integration is None: + # Compare against the schema directly if there's no integration + if error_package != package and field in integrations_or_schema: + del error_fields[field] + break + else: + # Compare against integration schemas + for integration, schema in integrations_or_schema.items(): + check_alt_schema = ( + error_package != package or # noqa: W504 + (error_package == package and error_integration != integration) + ) + if check_alt_schema and field in schema: + del error_fields[field] - # raise the first error + # Raise the first error if error_fields: - _, data = next(iter(error_fields.items())) - exc = data["error"] - trailer = data["trailer"] - - return kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) + _, error_data = next(iter(error_fields.items())) + return kql.KqlParseError( + error_data["error"].error_msg, + error_data["error"].line, + error_data["error"].column, + error_data["error"].source, + len(error_data["error"].caret.lstrip()), + error_data["trailer"], + ) class EQLValidator(QueryValidator): @@ -235,28 +279,37 @@ class EQLValidator(QueryValidator): if exc: raise exc - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[ - EQL_ERROR_TYPES, None, ValueError]: + def validate_integration(self, data: QueryRuleData, meta: RuleMeta, + package_integrations: List[dict]) -> Union[EQL_ERROR_TYPES, None, ValueError]: """Validate an EQL query while checking TOMLRule against integration schemas.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return error_fields = {} - current_stack_version = "" - combined_schema = {} - for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): - ecs_version = integration_schema_data['ecs_version'] - integration = integration_schema_data['integration'] - package = integration_schema_data['package'] - package_version = integration_schema_data['package_version'] - integration_schema = integration_schema_data['schema'] - stack_version = integration_schema_data['stack_version'] + package_schemas = {} - if stack_version != current_stack_version: - # reset the combined schema for each stack version - current_stack_version = stack_version - combined_schema = {} + # Initialize package_schemas with a nested structure + for integration_data in package_integrations: + package = integration_data["package"] + integration = integration_data["integration"] + if integration: + package_schemas.setdefault(package, {}).setdefault(integration, {}) + else: + package_schemas.setdefault(package, {}) + + # Process each integration schema + for integration_schema_data in get_integration_schema_data( + data, meta, package_integrations + ): + ecs_version = integration_schema_data["ecs_version"] + package, integration = ( + integration_schema_data["package"], + integration_schema_data["integration"], + ) + package_version = integration_schema_data["package_version"] + integration_schema = integration_schema_data["schema"] + stack_version = integration_schema_data["stack_version"] # add non-ecs-schema fields for edge cases not added to the integration for index_name in data.index: @@ -264,34 +317,65 @@ class EQLValidator(QueryValidator): # add endpoint schema fields for multi-line fields integration_schema.update(**ecs.flatten(ecs.get_endpoint_schemas())) - combined_schema.update(**integration_schema) + package_schemas[package].update(**integration_schema) eql_schema = ecs.KqlSchema2Eql(integration_schema) - err_trailer = f'stack: {stack_version}, integration: {integration},' \ - f'ecs: {ecs_version}, package: {package}, package_version: {package_version}' + err_trailer = ( + f"stack: {stack_version}, integration: {integration}," + f"ecs: {ecs_version}, package: {package}, package_version: {package_version}" + ) - exc = self.validate_query_with_schema(data=data, schema=eql_schema, err_trailer=err_trailer, - min_stack_version=meta.min_stack_version) + # Validate the query against the schema + exc = self.validate_query_with_schema( + data=data, + schema=eql_schema, + err_trailer=err_trailer, + min_stack_version=meta.min_stack_version, + ) if isinstance(exc, eql.EqlParseError): message = exc.error_msg if message == "Unknown field" or "Field not recognized" in message: - field = extract_error_field(exc) - trailer = (f"\n\tTry adding event.module or event.dataset to specify integration module\n\t" - f"Will check against integrations {meta.integration} combined.\n\t" - f"{package=}, {integration=}, {package_version=}, " - f"{stack_version=}, {ecs_version=}" - ) - error_fields[field] = {"error": exc, "trailer": trailer} + field = extract_error_field(self.query, exc) + trailer = ( + f"\n\tTry adding event.module or event.dataset to specify integration module\n\t" + f"Will check against integrations {meta.integration} combined.\n\t" + f"{package=}, {integration=}, {package_version=}, " + f"{stack_version=}, {ecs_version=}" + ) + error_fields[field] = { + "error": exc, + "trailer": trailer, + "package": package, + "integration": integration, + } if data.get("notify", False): - print(f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}") + print( + f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" + ) else: return exc - # don't error on fields that are in another integration schema - for field in list(error_fields.keys()): - if field in combined_schema: - del error_fields[field] + # Check error fields against schemas of different packages or different integrations + for field, error_data in list(error_fields.items()): + error_package, error_integration = ( + error_data["package"], + error_data["integration"], + ) + for package, integrations_or_schema in package_schemas.items(): + if error_integration is None: + # Compare against the schema directly if there's no integration + if error_package != package and field in integrations_or_schema: + del error_fields[field] + else: + # Compare against integration schemas + for integration, schema in integrations_or_schema.items(): + check_alt_schema = ( + error_package != package or # noqa: W504 + (error_package == package and error_integration != integration) + ) + if check_alt_schema and field in schema: + del error_fields[field] # raise the first error if error_fields: @@ -373,9 +457,9 @@ class ESQLValidator(QueryValidator): pass -def extract_error_field(exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]: +def extract_error_field(source: str, exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]: """Extract the field name from an EQL or KQL parse error.""" - lines = exc.source.splitlines() + lines = source.splitlines() mod = -1 if exc.line == len(lines) else 0 line = lines[exc.line + mod] start = exc.column