[FR] Add host family to data path (#2839)
* add rounding logic * cleaned up event_sort * fix linting * Added host_family to ndjson file path * linting fix * Added ability to manually supply host_os_family * fixed linting * Update detection_rules/utils.py Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com> * Update detection_rules/utils.py Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com> * linting updates --------- Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1e404cde34
commit
450e84ffa2
@@ -73,10 +73,10 @@ class RtaEvents:
|
||||
return events
|
||||
|
||||
@staticmethod
|
||||
def _get_dump_dir(rta_name=None, host_id=None):
|
||||
def _get_dump_dir(rta_name=None, host_id=None, host_os_family=None):
|
||||
"""Prepare and get the dump path."""
|
||||
if rta_name:
|
||||
dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name)
|
||||
if rta_name and host_os_family:
|
||||
dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name, host_os_family)
|
||||
os.makedirs(dump_dir, exist_ok=True)
|
||||
return dump_dir
|
||||
else:
|
||||
@@ -113,10 +113,20 @@ class RtaEvents:
|
||||
"""Save collected events."""
|
||||
assert self.events, 'Nothing to save. Run Collector.run() method first or verify logging'
|
||||
|
||||
dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id)
|
||||
host_os_family = None
|
||||
for key in self.events.keys():
|
||||
if self.events.get(key, {})[0].get('host', {}).get('id') == host_id:
|
||||
host_os_family = self.events.get(key, {})[0].get('host', {}).get('os').get('family')
|
||||
break
|
||||
if not host_os_family:
|
||||
click.echo('Unable to determine host.os.family for host_id: {}'.format(host_id))
|
||||
host_os_family = click.prompt("Please enter the host.os.family for this host_id",
|
||||
type=click.Choice(["windows", "macos", "linux"]), default="windows")
|
||||
|
||||
dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id, host_os_family=host_os_family)
|
||||
|
||||
for source, events in self.events.items():
|
||||
path = os.path.join(dump_dir, source + '.jsonl')
|
||||
path = os.path.join(dump_dir, source + '.ndjson')
|
||||
with open(path, 'w') as f:
|
||||
f.writelines([json.dumps(e, sort_keys=True) + '\n' for e in events])
|
||||
click.echo('{} events saved to: {}'.format(len(events), path))
|
||||
|
||||
@@ -191,8 +191,35 @@ def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[di
|
||||
def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True):
    """Sort events from elasticsearch by timestamp.

    :param events: iterable of event dicts, each carrying a timestamp string
    :param timestamp: key under which each event stores its timestamp
    :param date_format: strptime format the timestamp strings follow
    :param asc: sort ascending when True, descending when False
    :return: a new list of the events ordered by timestamp
    """

    def truncate_microseconds(t: str) -> str:
        """Truncate the fractional-seconds part of a timestamp string to 6 digits.

        strptime's %f accepts at most 6 digits, while elasticsearch can emit
        more. Truncation (rather than rounding) is used deliberately: rounding
        a fraction like '.9999999' to 6 places yields 1.0, which would drop the
        carry into the seconds field and corrupt the timestamp by ~1 second.
        Truncation preserves ordering and never overflows.
        """
        if not t:
            # Return early if the timestamp string is empty
            return t

        parts = t.split('.')
        if len(parts) == 2:
            # Remove the trailing "Z" (UTC designator) from the fractional part.
            # NOTE(review): assumes Zulu-suffixed timestamps; a numeric offset
            # like '+0000' would not be stripped here — confirm against the
            # data source.
            fractional = parts[1].rstrip('Z')
            if len(fractional) > 6:
                # Keep only the 6 most significant digits
                t = f"{parts[0]}.{fractional[:6]}Z"
        return t

    def _event_sort(event):
        """Calculate the sort key (epoch milliseconds) for an event."""
        t = truncate_microseconds(event[timestamp])
        seconds = time.mktime(time.strptime(t, date_format))

        # struct_time carries no sub-second precision, so recover the
        # fractional part from the string. Timestamps without a fractional
        # part (no '.') previously raised ValueError here; treat them as 0.
        parts = t.split('.')
        if len(parts) == 2:
            # Left-pad/truncate to exactly 6 digits so '123' (ms precision)
            # is read as 123000 microseconds, not 123.
            micros = int(parts[1].rstrip('Z').ljust(6, '0')[:6])
        else:
            micros = 0

        # seconds -> milliseconds, plus the microsecond remainder in ms.
        # (The previous formula added raw microseconds *before* scaling to
        # milliseconds, which inflated the sub-second contribution 1000x and
        # broke ordering across second boundaries.)
        return seconds * 1000 + micros / 1000

    return sorted(events, key=_event_sort, reverse=not asc)
|
||||
|
||||
+3
-3
@@ -32,7 +32,7 @@ def get_fp_data_files():
|
||||
return data
|
||||
|
||||
|
||||
def get_data_files_list(*folder, ext='jsonl', recursive=False):
|
||||
def get_data_files_list(*folder, ext='ndjson', recursive=False):
|
||||
"""Get TP or FP file list."""
|
||||
folder = os.path.sep.join(folder)
|
||||
data_dir = [DATA_DIR, folder]
|
||||
@@ -43,14 +43,14 @@ def get_data_files_list(*folder, ext='jsonl', recursive=False):
|
||||
return glob.glob(os.path.join(*data_dir), recursive=recursive)
|
||||
|
||||
|
||||
def get_data_files(*folder, ext='jsonl', recursive=False):
|
||||
def get_data_files(*folder, ext='ndjson', recursive=False):
|
||||
"""Get data from data files."""
|
||||
data_files = {}
|
||||
for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive):
|
||||
with open(data_file, 'r') as f:
|
||||
file_name = os.path.splitext(os.path.basename(data_file))[0]
|
||||
|
||||
if ext == 'jsonl':
|
||||
if ext in ('.ndjson', '.jsonl'):
|
||||
data = f.readlines()
|
||||
data_files[file_name] = [json.loads(d) for d in data]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user