# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

"""Shared helpers: OS detection, process execution, file patching, a small web server and registry access."""

from __future__ import unicode_literals, print_function

import binascii
import contextlib
import functools
import getpass
import inspect
import os
import platform
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
from http.server import HTTPServer, SimpleHTTPRequestHandler

long_t = type(1 << 63)

HOSTNAME = socket.gethostname()
LOCAL_IP = None


def get_ip() -> str:
    """Return the IP address that HOSTNAME resolves to, falling back to 127.0.0.1.

    The result is cached in the module-level LOCAL_IP after the first call.
    """
    global LOCAL_IP, HOSTNAME

    if LOCAL_IP is None:
        try:
            LOCAL_IP = socket.gethostbyname(HOSTNAME)
        except socket.gaierror:
            LOCAL_IP = "127.0.0.1"

    return LOCAL_IP


def get_winreg():
    """Import and return the registry module (`_winreg` if available, otherwise `winreg`)."""
    try:
        import _winreg as winreg
    except ImportError:
        import winreg

    return winreg


# Multi-OS Support
WINDOWS = "windows"
MACOS = "macos"
LINUX = "linux"

if sys.platform == "darwin":
    CURRENT_OS = MACOS
elif sys.platform.startswith("win"):
    CURRENT_OS = WINDOWS
else:
    CURRENT_OS = LINUX

if CURRENT_OS == WINDOWS:
    CMD_PATH = os.environ.get("COMSPEC")
    POWERSHELL_PATH = "C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"
else:
    CMD_PATH = "/bin/sh"
    POWERSHELL_PATH = None

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ALL_IP = "0.0.0.0"
IP_REGEX = r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"
CALLBACK_REGEX = r"https?://" + IP_REGEX + r":\d+"


def _get_user_name() -> str:
    """Return the lower-cased login name, or "unknown" when it cannot be determined."""
    try:
        return getpass.getuser().lower()
    except (KeyError, OSError, ImportError):
        # getpass.getuser() fails when no environment variable or account entry names the user.
        # Importing this module should not abort because of that.
        return "unknown"


USER_NAME = _get_user_name()

SUCCESS = 0
PYTHON_ERROR = 1  # Python does this internally, so we don't want to overwrite it
GENERAL_ERROR = 2
MISSING_DEPENDENCIES = 3
MISSING_PSEXEC = 4
ACCESS_DENIED = 5
UNSUPPORTED_RTA = 6
MISSING_REMOTE_HOST = 7

# Minimum number of seconds a command should take.
# This allows scripts to be slowed down arbitrarily.
MIN_EXECUTION_TIME = 0

MAX_HOSTS = 64

# Useful constants
HKLM = "hklm"
HKCU = "hkcu"
HKU = "hku"
HKCR = "hkcr"

SZ = "sz"
EXPAND_SZ = "expand_sz"
MULTI_SZ = "multi_sz"
DWORD = "dword"

OS_MAPPING = {WINDOWS: [], MACOS: [], LINUX: []}


def requires_os(*os_list: str):
    """Decorator factory restricting a function to the given operating systems.

    Accepts either several OS names or a single list/tuple of names. At decoration time the
    function's module name (last dotted component) is appended to OS_MAPPING for each OS.
    When called on an unsupported OS the wrapper logs a message and returns UNSUPPORTED_RTA
    instead of invoking the function.
    """
    if len(os_list) == 1 and isinstance(os_list[0], (list, tuple)):
        os_list = os_list[0]

    def decorator(f):
        # Register this function with the supported OS mapping
        for os_type in os_list:
            OS_MAPPING[os_type].append(f.__module__.split(".")[-1])

        @functools.wraps(f)
        def decorated(*args, **kwargs):
            if CURRENT_OS not in os_list:
                filename = os.path.relpath(inspect.getsourcefile(f))
                func_name = f.__name__

                log(f"Unsupported OS for {filename}:{func_name}(). Expected {'/'.join(os_list)}", "!")
                return UNSUPPORTED_RTA
            return f(*args, **kwargs)

        return decorated

    return decorator


def check_dependencies(*paths: str) -> bool:
    """Log every path that does not exist and return True only when all of them exist."""
    missing = []
    for path in paths:
        if not os.path.exists(path):
            log("Missing dependency %s" % path, "!")
            missing.append(path)
    return len(missing) == 0


def dependencies(*paths: str):
    """Decorator factory that skips the function when any of `paths` is missing.

    Existence is checked once, when the decorator is created. If something was missing, calling
    the decorated function logs the missing paths (relative to BASE_DIR) and returns
    MISSING_DEPENDENCIES without running the function.
    """
    missing = [path for path in paths if not os.path.exists(path)]

    def decorator(f):
        @functools.wraps(f)
        def decorated(*args, **kwargs):
            if missing:
                # `__code__` is the Python 3 name of the code object (`func_code` no longer exists).
                code = f.__code__
                log("Missing dependencies for %s:%s()" % (code.co_filename, code.co_name), "!")
                for dep in missing:
                    print("    - %s" % os.path.relpath(dep, BASE_DIR))
                return MISSING_DEPENDENCIES
            return f(*args, **kwargs)

        return decorated

    return decorator


def pause():
    """Sleep for half a second."""
    time.sleep(0.5)


def get_path(*path: str) -> str:
    """Join `path` components onto BASE_DIR and return the result as a string."""
    return str(Path(BASE_DIR).joinpath(*path))


@contextlib.contextmanager
def temporary_file(contents, file_name=None):
    """Context manager yielding an open read handle to a file holding `contents`.

    The file is closed and deleted when the context exits.
    """
    handle, close = temporary_file_helper(contents, file_name)

    try:
        yield handle
    finally:
        close()


def temporary_file_helper(contents, file_name=None):
    """Write `contents` to a file and return `(read_handle, close)`.

    A relative or missing `file_name` is placed in the system temp directory; when no name is
    given one is derived from `hash(contents)`. Binary mode is used for `bytes` contents, text
    mode otherwise. Calling `close()` closes the handle and removes the file.
    """
    if not (file_name and os.path.isabs(file_name)):
        file_name = os.path.join(tempfile.gettempdir(), file_name or f"temp{hash(contents):d}")

    is_binary = isinstance(contents, bytes)

    with open(file_name, "wb" if is_binary else "w") as f:
        f.write(contents)

    f = open(file_name, "rb" if is_binary else "r")

    def close():
        f.close()
        os.remove(file_name)

    return f, close


def _wait_or_kill(process, timeout):
    """Poll `process` every 0.5s for roughly `timeout` seconds, then kill it if still running."""
    delta = 0.5

    # Try waiting for the process to die
    for _ in range(int(timeout / delta) + 1):
        time.sleep(delta)
        if process.poll() is not None:
            return

    log("Killing process", str(process.pid))
    try:
        process.kill()
        time.sleep(0.5)
    except OSError:
        pass


def _collect_output(process, echo):
    """Read the combined output of `process` until it exits, optionally echoing each line."""
    if process.stdout is None:
        # Output was redirected to the null device (drop=True), so there is nothing to read.
        process.wait()
        return ""

    chunks = []
    while process.poll() is None:
        line = process.stdout.readline().decode("ascii", "ignore")
        if line:
            chunks.append(line)
            if echo:
                print(line.rstrip())

    chunks.append(process.stdout.read().decode("ascii", "ignore"))
    return "".join(chunks).strip()


def execute(
    command: Iterable,
    hide_log=False,
    mute=False,
    timeout: int = 30,
    wait=True,
    kill=False,
    drop=False,
    stdin: Optional[Union[bytes, str]] = None,
    shell=False,
    **kwargs,
):
    """Execute a process and get the output.

    :param command: argument list (preferred) or a command string (prints a deprecation warning)
    :param hide_log: do not print the command line or any output
    :param mute: print the command line but not the output
    :param timeout: seconds to wait before killing the process; only used when `kill` is set
    :param wait: block until the process exits and return its output
    :param kill: wait up to `timeout` seconds, then kill the process; output is discarded
    :param drop: discard the process output
    :param stdin: optional bytes/str fed to the process through a temporary file
    :param shell: passed to `subprocess.Popen`; a list command is joined into a string first
    :param kwargs: extra keyword arguments for `subprocess.Popen`
    :return: None when `kill` is set, `(returncode, output)` when `wait` is set,
             otherwise the running `subprocess.Popen` object
    """
    command_string = command
    close = None

    if isinstance(command, (list, tuple)):
        command_string = subprocess.list2cmdline(command)
        if shell:
            command = command_string
    else:
        sys.stderr.write("Deprecation warning! Switch arguments to a list for common.execute()\n\n")

    if not hide_log:
        print("%s @ %s > %s" % (USER_NAME, HOSTNAME, command_string))

    if isinstance(stdin, (bytes, str)):
        stdin, close = temporary_file_helper(stdin)

    stdout = subprocess.PIPE
    stderr = subprocess.STDOUT

    if drop or kill:
        # subprocess.DEVNULL avoids leaking a handle to os.devnull on every call
        stdout = subprocess.DEVNULL
        stderr = subprocess.DEVNULL

    try:
        start = time.time()

        p = subprocess.Popen(
            command, stdin=stdin or subprocess.PIPE, stdout=stdout, stderr=stderr, shell=shell, **kwargs
        )

        if kill:
            _wait_or_kill(p, timeout)
            return None

        if not wait:
            return p

        if not stdin:
            try:
                p.stdin.write(os.linesep.encode("ascii"))
                # The pipe is buffered, so flush or the newline never reaches the child.
                p.stdin.flush()
            except IOError:
                # this pipe randomly breaks when executing certain non-zero exit commands on linux
                pass

        echo = not (hide_log or mute)
        output = _collect_output(p, echo)

        # Add artificial sleep to slow down command lines
        run_time = time.time() - start
        if run_time < MIN_EXECUTION_TIME:
            time.sleep(MIN_EXECUTION_TIME - run_time)

        if echo:
            if p.returncode != 0:
                print("exit code = %d" % p.returncode)
            print("")

        return p.returncode, output
    finally:
        # Always release the temporary stdin file, including on the kill path and on errors.
        if close:
            close()


def log(message, log_type="+"):
    """Print `message` prefixed with `[log_type]`."""
    print("[%s] %s" % (log_type, message))


def copy_file(source, target):
    """Log and copy `source` to `target` with `shutil.copy`."""
    log("Copying %s -> %s" % (source, target))
    shutil.copy(source, target)


def create_macos_masquerade(masquerade: str):
    """Copy the bundled `ditto_and_spawn` binary for this processor type to `masquerade`."""
    if platform.processor() == "arm":
        name = "com.apple.ditto_and_spawn_arm"
    else:
        name = "com.apple.ditto_and_spawn_intel"
    source = get_path("bin", name)
    copy_file(source, masquerade)


def link_file(source, target):
    """Log and create a symbolic link `target` -> `source` by running `ln -s`."""
    log("Linking %s -> %s" % (source, target))
    execute(["ln", "-s", source, target])


def remove_file(path):
    """Delete `path` if it exists, retrying up to three times with a short sleep between tries."""
    if os.path.exists(path):
        log("Removing %s" % path, log_type="-")
        # Try three times to remove the file
        for _ in range(3):
            try:
                os.remove(path)
            except OSError:
                time.sleep(0.25)
            else:
                return


def remove_directory(path):
    """Recursively delete `path` if it is a directory; delete it as a file otherwise."""
    if os.path.exists(path):
        if os.path.isdir(path):
            # No ":s" format spec: it raises TypeError for pathlib.Path arguments.
            log(f"Removing directory {path}", log_type="-")
            shutil.rmtree(path)
        else:
            remove_file(path)


def is_64bit():
    """Return True when the PROCESSOR_ARCHITECTURE environment variable is x64 or AMD64."""
    return os.environ.get("PROCESSOR_ARCHITECTURE", "") in ("x64", "AMD64")


def remove_files(*paths):
    """Call `remove_file` on every given path."""
    for path in paths:
        remove_file(path)


def clear_web_cache():
    """Run RunDll32 with InetCpl.cpl ClearMyTracksByProcess 8, then sleep for one second."""
    log("Clearing temporary files", log_type="-")
    execute(["RunDll32.exe", "InetCpl.cpl,", "ClearMyTracksByProcess", "8"], hide_log=True)
    time.sleep(1)


def serve_web(ip=None, port=None, directory=BASE_DIR):
    """Start a background HTTP server for `directory` and return `(server, ip, port)`.

    :param ip: address to bind; defaults to `get_ip()`
    :param port: port to bind; when None the first free port in 8000-8999 is used
    :param directory: directory served (the serving thread changes the process working directory)
    :raises OSError: if no port could be bound
    """
    handler = SimpleHTTPRequestHandler
    ip = ip or get_ip()
    server = None

    if port is not None:
        server = HTTPServer((ip, port), handler)
    else:
        # Otherwise, try to find a port
        for port in range(8000, 9000):
            try:
                server = HTTPServer((ip, port), handler)
                break
            except socket.error:
                pass

        if server is None:
            raise OSError("Unable to find a free port between 8000 and 8999 on %s" % ip)

    def server_thread():
        log(f"Starting web server on http://{ip}:{port:d} for directory {directory}")
        os.chdir(directory)
        server.serve_forever()

    # Start this thread in the background
    thread = threading.Thread(target=server_thread, daemon=True)
    thread.start()
    time.sleep(0.5)
    return server, ip, port


def patch_file(source_file, old_bytes, new_bytes, target_file=None):
    """Replace every occurrence of `old_bytes` with `new_bytes` and write the result.

    :param source_file: file to read
    :param target_file: file to write; defaults to `source_file` (patch in place)
    """
    target_file = target_file or source_file
    log(
        "Patching bytes %s [%s] --> %s [%s]"
        % (source_file, binascii.b2a_hex(old_bytes), target_file, binascii.b2a_hex(new_bytes))
    )

    with open(source_file, "rb") as f:
        contents = f.read()

    patched = contents.replace(old_bytes, new_bytes)

    with open(target_file, "wb") as f:
        f.write(patched)


def patch_regex(source_file, regex, new_bytes, target_file=None):
    """Substitute every match of `regex` in the file's bytes with `new_bytes`.

    `regex` and `new_bytes` are given as text and encoded to ASCII before use. The result is
    written to `target_file`, which defaults to `source_file`.
    """
    regex = regex.encode("ascii")
    new_bytes = new_bytes.encode("ascii")

    target_file = target_file or source_file
    log("Patching by regex %s --> %s" % (source_file, target_file))

    with open(source_file, "rb") as f:
        contents = f.read()

    matches = re.findall(regex, contents)
    log("Changing %s -> %s" % (", ".join("{}".format(m) for m in matches), new_bytes))
    contents = re.sub(regex, new_bytes, contents)

    with open(target_file, "wb") as f:
        f.write(contents)


def wchar(s):
    """Encode `s` as UTF-16 little endian bytes."""
    return s.encode("utf-16le")


def find_remote_host():
    """Return the IP of a host listed by `net view` that answers `sc.exe query`.

    At most MAX_HOSTS candidates are probed in parallel and polled for about ten seconds.
    Falls back to `get_ip()` when no host qualifies.
    """
    log("Searching for remote Windows hosts")
    _, stdout = execute(["net", "view"], hide_log=True)
    hosts = re.findall(r"\\\\([\w\d\._-]+)", stdout)

    # _, current_user = execute("whoami", hide_log=True)
    pending = {}

    log("Discovery %d possible hosts" % len(hosts))
    for name in hosts[:MAX_HOSTS]:
        name = name.lower()
        if name.split(".")[0] == HOSTNAME.split(".")[0]:
            continue

        # log("Checking if %s has remote admin permissions to %s" % (current_user, name))
        p = subprocess.Popen(
            ["sc.exe", "\\\\%s" % name, "query"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            stdin=subprocess.PIPE,
        )
        pending[name] = p

    if pending:
        # See which ones return first with a success code, and use that host
        for _ in range(20):
            # sorted() copies the items, so entries can be removed while looping
            for hostname, pending_process in sorted(pending.items()):
                if pending_process.poll() is None:
                    try:
                        # The pipe carries bytes, so the line separator must be encoded.
                        pending_process.stdin.write(os.linesep.encode("ascii"))
                        pending_process.stdin.flush()
                    except OSError:
                        pass
                    continue

                if pending_process.returncode == 0:
                    # Now need to get the IP address
                    ip = get_ipv4_address(hostname)
                    if ip is not None:
                        log("Using remote host %s (%s)" % (ip, hostname))
                        return ip

                # The probe finished without yielding a usable host; stop tracking it.
                pending.pop(hostname)
            time.sleep(0.5)

    log("Unable to find a remote host to pivot to. Using local host %s" % HOSTNAME, log_type="!")
    return get_ip()


def get_ipv4_address(hostname):
    """Return `hostname` if it already is a dotted IPv4 address, else resolve it with `ping`.

    Returns None when ping fails or its output contains no address.
    """
    if re.fullmatch(IP_REGEX, hostname):
        return hostname

    # Every argv element must be a string for subprocess.list2cmdline
    code, output = execute(["ping", hostname, "-4", "-n", "1"], hide_log=True)
    if code != 0:
        return None

    addresses = re.findall(IP_REGEX, output)
    if len(addresses) == 0:
        return None
    return addresses[0]


def find_writeable_directory(base_dir):
    """Walk `base_dir` and return the first sub-directory in which a test file can be created.

    Returns None when no sub-directory is writeable.
    """
    for root, dirs, files in os.walk(base_dir):
        for d in dirs:
            # Join onto the directory currently being walked, not the top-level one
            subdir = os.path.join(root, d)
            test_file = os.path.join(subdir, "test_file")
            try:
                with open(test_file, "w"):
                    pass
                os.remove(test_file)
                return subdir
            except IOError:
                pass


def check_system():
    """Return True when USER_NAME is "system" or ends with "$"."""
    return USER_NAME == "system" or USER_NAME.endswith("$")


PS_EXEC = get_path("bin", "PsExec.exe")


def run_system(arguments=None):
    """Re-run a command under PsExec with `-s`.

    :param arguments: command line to run; defaults to the current Python script and its arguments
    :return: None if `check_system()` is already true, MISSING_PSEXEC if PsExec is not found,
             otherwise the PsExec exit code
    """
    if check_system():
        return None

    if arguments is None:
        arguments = [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]

    log("Attempting to elevate to SYSTEM using PsExec")
    if not os.path.exists(PS_EXEC):
        log("PsExec not found", log_type="-")
        return MISSING_PSEXEC

    p = subprocess.Popen([PS_EXEC, "-w", os.getcwd(), "-accepteula", "-s"] + arguments)
    p.wait()
    code = p.returncode
    if code == ACCESS_DENIED:
        log("Failed to escalate to SYSTEM", "!")
    return code


def _resolve_hive(winreg, hive):
    """Translate a hive short name ("hklm", "hkcu", "hku", "hkcr") into its winreg constant.

    Non-string values are returned unchanged.
    """
    if isinstance(hive, str):
        hives = {
            "hklm": winreg.HKEY_LOCAL_MACHINE,
            "hkcu": winreg.HKEY_CURRENT_USER,
            "hku": winreg.HKEY_USERS,
            "hkcr": winreg.HKEY_CLASSES_ROOT,
        }
        return hives[hive.lower()]
    return hive


def write_reg(
    hive: str,
    key: str,
    value: str,
    data: Union[str, int],
    data_type: Union[str, int, list],
    restore=True,
    pause=False,
    append=False,
) -> None:
    """Write a registry value by entering and immediately leaving `temporary_reg`.

    With `restore=True` (the default) the previous state is put back right away.
    """
    with temporary_reg(hive, key, value, data, data_type, restore, pause, append):
        pass


def read_reg(hive: str, key: str, value: str) -> Tuple[Optional[object], Optional[int]]:
    """Return `(data, type)` for a registry value, or `(None, None)` when it does not exist.

    The key is opened with `winreg.CreateKey`, so a missing key is created.

    :raises OSError: for registry errors other than errno 2
    """
    winreg = get_winreg()
    hive = _resolve_hive(winreg, hive)

    try:
        hkey = winreg.CreateKey(hive, key.rstrip("\\"))
        try:
            old_data, old_type = winreg.QueryValueEx(hkey, value)
        finally:
            hkey.Close()
    except OSError as e:
        # check if the key already exists
        if e.errno != 2:
            raise

        return None, None

    return old_data, old_type


@contextlib.contextmanager
def temporary_reg(
    hive: str,
    key: str,
    value: str,
    data: Union[str, int],
    data_type: Union[str, int, list] = "sz",
    restore=True,
    pause=False,
    append=False,
):
    """Context manager that sets a registry value for the duration of the `with` block.

    :param hive: hive short name ("hklm", "hkcu", "hku", "hkcr") or a winreg hive constant
    :param key: registry key path; trailing backslashes are stripped
    :param value: value name
    :param data: data to write
    :param data_type: type suffix such as "sz" or "dword" (looked up as `winreg.REG_<TYPE>`)
                      or a winreg type constant
    :param restore: on exit, restore the previous data or delete the value if it was new
    :param pause: sleep half a second after leaving the context
    :param append: if the value already exists, append `data` to the existing list data
    """
    winreg = get_winreg()
    hive = _resolve_hive(winreg, hive)

    if isinstance(data_type, str):
        attr = "REG_" + data_type.upper()
        data_type = getattr(winreg, attr)

    if data_type is None:
        data_type = winreg.REG_SZ

    key = key.rstrip("\\")
    hkey = winreg.CreateKey(hive, key)
    exists = False
    old_data = None
    old_type = None

    if hkey:
        try:
            old_data, old_type = winreg.QueryValueEx(hkey, value)
            exists = True
        except OSError as e:
            # check if the key already exists
            exists = False
            old_data, old_type = None, None
            if e.errno != 2:
                raise

    if append and exists:
        # If appending to the existing REG_MULTI_SZ key, then append to the end
        if not isinstance(data, list):
            data = [data]

        if isinstance(old_data, list):
            data = old_data + data

    data_string = ",".join(data) if isinstance(data, list) else data
    log("Writing to registry %s\\%s -> %s" % (key, value, data_string))
    winreg.SetValueEx(hkey, value, 0, data_type, data)
    stored, code = winreg.QueryValueEx(hkey, value)

    if data != stored:
        log("Wrote %s but retrieved %s" % (data, stored), log_type="-")

    # Allow code to execute within the context manager 'with'
    try:
        yield

    finally:
        if restore:
            time.sleep(0.5)

            if not exists:
                # If it didn't already exist, then delete it
                log("Deleting %s\\%s" % (key, value), log_type="-")
                winreg.DeleteValue(hkey, value)
            else:
                # Otherwise restore the value
                data_string = ",".join(old_data) if isinstance(old_data, list) else old_data
                log("Restoring registry %s\\%s -> %s" % (key, value, data_string), log_type="-")
                winreg.SetValueEx(hkey, value, 0, old_type, old_data)

        hkey.Close()
        print("")

        if pause:
            time.sleep(0.5)


def enable_logon_auditing(host="localhost", verbose=True, sleep=2):
    """Enable logon auditing on local or remote system to enable 4624 and 4625 events.

    Runs `auditpol.exe` on `host` through PowerShell `Invoke-WmiMethod`, then sleeps `sleep`
    seconds. Returns the `(returncode, output)` tuple from `execute`.
    """
    if verbose:
        log(f"Ensuring audit logging enabled on {host}")

    auditpol = "auditpol.exe /set /subcategory:Logon /failure:enable /success:enable"
    enable_logging = "Invoke-WmiMethod -ComputerName {} -Class Win32_process -Name create -ArgumentList '{}'".format(
        host, auditpol
    )
    command = ["powershell", "-c", enable_logging]
    enable = execute(command)

    # additional time to allow auditing to process
    time.sleep(sleep)
    return enable


def print_file(path):
    """Print `path` followed by the file's contents, or a NOT FOUND marker if it is missing."""
    print(path)
    if not os.path.exists(path):
        print("--- NOT FOUND ----")
    else:
        print("-" * 16)
        with open(path, "r") as f:
            print(f.read().rstrip())

    print("")