diff --git a/detection_rules/eswrap.py b/detection_rules/eswrap.py index fa056c6b0..77968723e 100644 --- a/detection_rules/eswrap.py +++ b/detection_rules/eswrap.py @@ -73,10 +73,10 @@ class RtaEvents: return events @staticmethod - def _get_dump_dir(rta_name=None, host_id=None): + def _get_dump_dir(rta_name=None, host_id=None, host_os_family=None): """Prepare and get the dump path.""" - if rta_name: - dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name) + if rta_name and host_os_family: + dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name, host_os_family) os.makedirs(dump_dir, exist_ok=True) return dump_dir else: @@ -113,10 +113,20 @@ class RtaEvents: """Save collected events.""" assert self.events, 'Nothing to save. Run Collector.run() method first or verify logging' - dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id) + host_os_family = None + for key in self.events.keys(): + if self.events.get(key, {})[0].get('host', {}).get('id') == host_id: + host_os_family = self.events.get(key, {})[0].get('host', {}).get('os').get('family') + break + if not host_os_family: + click.echo('Unable to determine host.os.family for host_id: {}'.format(host_id)) + host_os_family = click.prompt("Please enter the host.os.family for this host_id", + type=click.Choice(["windows", "macos", "linux"]), default="windows") + + dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id, host_os_family=host_os_family) for source, events in self.events.items(): - path = os.path.join(dump_dir, source + '.jsonl') + path = os.path.join(dump_dir, source + '.ndjson') with open(path, 'w') as f: f.writelines([json.dumps(e, sort_keys=True) + '\n' for e in events]) click.echo('{} events saved to: {}'.format(len(events), path)) diff --git a/detection_rules/utils.py b/detection_rules/utils.py index 749d4ad2a..ee3d01b56 100644 --- a/detection_rules/utils.py +++ b/detection_rules/utils.py @@ -191,8 +191,35 @@ def 
def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True):
    """Sort events from elasticsearch by timestamp.

    Args:
        events: Iterable of event mappings; each must contain the ``timestamp`` key.
        timestamp: Key holding the timestamp string in every event.
        date_format: ``time.strptime`` format describing the timestamp strings.
        asc: Sort oldest-first when true, newest-first otherwise.

    Returns:
        A new list with the events ordered by their timestamp.

    Raises:
        KeyError: If an event has no ``timestamp`` key.
        ValueError: If a timestamp does not match ``date_format``.
    """
    import calendar  # local import so the block does not depend on the module's import list

    def _split_fraction(t):
        """Return ``(t trimmed to at most 6 fractional digits, fraction of a second)``."""
        if not t or '%f' not in date_format:
            # Nothing to normalize; let strptime report malformed values itself.
            return t, 0.0

        head, dot, tail = t.rpartition('.')
        # Everything after the leading digits is the zone designator ("Z", "+01:00", ...).
        suffix = tail.lstrip('0123456789')
        digits = tail[:len(tail) - len(suffix)]
        if not dot or not digits:
            return t, 0.0

        # %f accepts at most 6 digits, so truncate for parsing only. Truncating (rather than
        # rounding) can never carry into the seconds field, and the full-precision fraction is
        # still used for ordering.
        return f'{head}.{digits[:6]}{suffix}', float(f'0.{digits}')

    def _event_sort(event):
        """Calculate the sort key ``(epoch seconds in UTC, fraction of a second)`` for an event."""
        t, fraction = _split_fraction(event[timestamp])
        parsed = time.strptime(t, date_format)

        # Interpret the parsed fields as UTC and apply the parsed offset (if any) instead of
        # going through local time, which is ambiguous around DST transitions.
        utc_offset = getattr(parsed, 'tm_gmtoff', None) or 0
        return calendar.timegm(parsed) - utc_offset, fraction

    return sorted(events, key=_event_sort, reverse=not asc)
get_data_files(*folder, ext='jsonl', recursive=False): +def get_data_files(*folder, ext='ndjson', recursive=False): """Get data from data files.""" data_files = {} for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive): with open(data_file, 'r') as f: file_name = os.path.splitext(os.path.basename(data_file))[0] - if ext == 'jsonl': + if ext in ('.ndjson', '.jsonl'): data = f.readlines() data_files[file_name] = [json.loads(d) for d in data] else: