diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 209fc185b296..9bcd6d23d17c 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -145,6 +145,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--use-subprocess-workers", action="store_true", help="Execute each prompt in an isolated subprocess with complete GPU/ROCm context reset. Ensures clean state between jobs but adds startup overhead.")
+parser.add_argument("--subprocess-timeout", type=int, default=600, help="Timeout in seconds for subprocess execution (default: 600, only used with --use-subprocess-workers).")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
 
 class PerformanceFeature(enum.Enum):
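Both flags are opt-in and only matter together: `--subprocess-timeout` is read only when `--use-subprocess-workers` is set. A typical invocation for long-running workflows might look like this (the value 1200 is illustrative, not a recommendation):

    python main.py --use-subprocess-workers --subprocess-timeout 1200

The trade-off named in the help text is real: each prompt pays full process startup and model-load cost in exchange for a guaranteed-fresh GPU/ROCm context.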
diff --git a/comfy/execution_core.py b/comfy/execution_core.py
new file mode 100644
index 000000000000..845d2ac15d99
--- /dev/null
+++ b/comfy/execution_core.py
@@ -0,0 +1,145 @@
+"""Core execution logic shared between normal and subprocess execution modes."""
+
+import logging
+import time
+
+_active_worker = None
+
+
+def create_worker(server_instance):
+    """Create worker backend. Returns NativeWorker or SubprocessWorker."""
+    global _active_worker
+    from comfy.cli_args import args
+
+    server = WorkerServer(server_instance)
+
+    if args.use_subprocess_workers:
+        from comfy.worker_process import SubprocessWorker
+        worker = SubprocessWorker(server, timeout=args.subprocess_timeout)
+    else:
+        from comfy.worker_native import NativeWorker
+        worker = NativeWorker(server)
+
+    _active_worker = worker
+    return worker
+
+
+async def init_execution_environment():
+    """Load nodes and custom nodes. Returns number of node types loaded."""
+    import nodes
+    from comfy.cli_args import args
+
+    await nodes.init_extra_nodes(
+        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
+        init_api_nodes=not args.disable_api_nodes
+    )
+    return len(nodes.NODE_CLASS_MAPPINGS)
+
+
+def setup_progress_hook(server_instance, interrupt_checker):
+    """Set up the global progress hook. interrupt_checker must raise on interrupt."""
+    import comfy.utils
+    from comfy_execution.progress import get_progress_state
+    from comfy_execution.utils import get_executing_context
+
+    def hook(value, total, preview_image, prompt_id=None, node_id=None):
+        ctx = get_executing_context()
+        if ctx:
+            prompt_id = prompt_id or ctx.prompt_id
+            node_id = node_id or ctx.node_id
+
+        interrupt_checker()
+
+        prompt_id = prompt_id or server_instance.last_prompt_id
+        node_id = node_id or server_instance.last_node_id
+
+        get_progress_state().update_progress(node_id, value, total, preview_image)
+        server_instance.send_sync("progress", {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}, server_instance.client_id)
+
+    comfy.utils.set_progress_bar_global_hook(hook)
+
+
+class WorkerServer:
+    """Protocol boundary: client_id, last_node_id, last_prompt_id, sockets_metadata, send_sync(), queue_updated()"""
+
+    _WRITABLE = {'client_id', 'last_node_id', 'last_prompt_id'}
+
+    def __init__(self, server):
+        object.__setattr__(self, '_server', server)
+
+    def __setattr__(self, name, value):
+        if name in self._WRITABLE:
+            setattr(self._server, name, value)
+        else:
+            raise AttributeError(f"WorkerServer does not accept attribute '{name}'")
+
+    @property
+    def client_id(self):
+        return self._server.client_id
+
+    @property
+    def last_node_id(self):
+        return self._server.last_node_id
+
+    @property
+    def last_prompt_id(self):
+        return self._server.last_prompt_id
+
+    @property
+    def sockets_metadata(self):
+        return self._server.sockets_metadata
+
+    def send_sync(self, event, data, sid=None):
+        self._server.send_sync(event, data, sid or self.client_id)
+
+    def queue_updated(self):
+        self._server.queue_updated()
+
+
+def interrupt_processing(value=True):
+    _active_worker.interrupt(value)
+
+
+def _strip_sensitive(prompt):
+    return prompt[:5] + prompt[6:]
+
+
+def prompt_worker(q, worker):
+    """Main prompt execution loop."""
+    import execution
+
+    server = worker.server_instance
+
+    while True:
+        queue_item = q.get(timeout=worker.get_gc_timeout())
+        if queue_item is not None:
+            item, item_id = queue_item
+            start_time = time.perf_counter()
+            prompt_id = item[1]
+            server.last_prompt_id = prompt_id
+
+            extra_data = {**item[3], **item[5]}
+
+            result = worker.execute_prompt(item[2], prompt_id, extra_data, item[4], server=server)
+            worker.mark_needs_gc()
+
+            q.task_done(
+                item_id,
+                result['history_result'],
+                status=execution.PromptQueue.ExecutionStatus(
+                    status_str='success' if result['success'] else 'error',
+                    completed=result['success'],
+                    messages=result['status_messages']
+                ),
+                process_item=_strip_sensitive
+            )
+
+            if server.client_id is not None:
+                server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
+
+            elapsed = time.perf_counter() - start_time
+            if elapsed > 600:
+                logging.info(f"Prompt executed in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
+            else:
+                logging.info(f"Prompt executed in {elapsed:.2f} seconds")
+
+        worker.handle_flags(q.get_flags())
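`WorkerServer` is deliberately the entire surface a worker may touch: reads go through properties, writes are restricted to the three names in `_WRITABLE`, and everything else raises. A minimal sketch of that behavior, assuming a running `prompt_server` instance:

    from comfy.execution_core import WorkerServer

    ws = WorkerServer(prompt_server)
    ws.last_prompt_id = "abc123"      # in _WRITABLE: forwarded to prompt_server
    ws.send_sync("status", {"x": 1})  # delegates; sid defaults to client_id
    ws.sockets_metadata = {}          # AttributeError: not a writable attribute

Keeping the boundary this narrow is what lets the same `prompt_worker` loop run against either the real server (native mode) or the IPC shim in a child process (subprocess mode).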
diff --git a/comfy/worker_native.py b/comfy/worker_native.py
new file mode 100644
index 000000000000..244b713deab7
--- /dev/null
+++ b/comfy/worker_native.py
@@ -0,0 +1,95 @@
+"""Native (in-process) worker for prompt execution."""
+
+import time
+import gc
+
+
+class NativeWorker:
+    """Executes prompts in the same process as the server."""
+
+    def __init__(self, server_instance, interrupt_checker=None):
+        self.server_instance = server_instance
+        self.interrupt_checker = interrupt_checker
+        self.executor = None
+        self.last_gc_collect = 0
+        self.need_gc = False
+        self.gc_collect_interval = 10.0
+
+    async def initialize(self):
+        """Load nodes and set up executor. Returns node count."""
+        from execution import PromptExecutor, CacheType
+        from comfy.cli_args import args
+        from comfy.execution_core import init_execution_environment, setup_progress_hook
+        import comfy.model_management as mm
+        import hook_breaker_ac10a0
+
+        hook_breaker_ac10a0.save_functions()
+        try:
+            node_count = await init_execution_environment()
+        finally:
+            hook_breaker_ac10a0.restore_functions()
+
+        interrupt_checker = self.interrupt_checker or mm.throw_exception_if_processing_interrupted
+        setup_progress_hook(self.server_instance, interrupt_checker=interrupt_checker)
+
+        cache_type = CacheType.CLASSIC
+        if args.cache_lru > 0:
+            cache_type = CacheType.LRU
+        elif args.cache_ram > 0:
+            cache_type = CacheType.RAM_PRESSURE
+        elif args.cache_none:
+            cache_type = CacheType.NONE
+
+        self.executor = PromptExecutor(
+            self.server_instance,
+            cache_type=cache_type,
+            cache_args={"lru": args.cache_lru, "ram": args.cache_ram}
+        )
+        return node_count
+
+    def execute_prompt(self, prompt, prompt_id, extra_data, execute_outputs, server=None):
+        self.executor.execute(prompt, prompt_id, extra_data, execute_outputs)
+        return {
+            'success': self.executor.success,
+            'history_result': self.executor.history_result,
+            'status_messages': self.executor.status_messages,
+            'prompt_id': prompt_id
+        }
+
+    def handle_flags(self, flags):
+        import comfy.model_management as mm
+        import hook_breaker_ac10a0
+
+        free_memory = flags.get("free_memory", False)
+
+        if flags.get("unload_models", free_memory):
+            mm.unload_all_models()
+            self.need_gc = True
+            self.last_gc_collect = 0
+
+        if free_memory:
+            if self.executor:
+                self.executor.reset()
+            self.need_gc = True
+            self.last_gc_collect = 0
+
+        if self.need_gc:
+            current_time = time.perf_counter()
+            if (current_time - self.last_gc_collect) > self.gc_collect_interval:
+                gc.collect()
+                mm.soft_empty_cache()
+                self.last_gc_collect = current_time
+                self.need_gc = False
+                hook_breaker_ac10a0.restore_functions()
+
+    def interrupt(self, value=True):
+        import comfy.model_management
+        comfy.model_management.interrupt_current_processing(value)
+
+    def mark_needs_gc(self):
+        self.need_gc = True
+
+    def get_gc_timeout(self):
+        if self.need_gc:
+            return max(self.gc_collect_interval - (time.perf_counter() - self.last_gc_collect), 0.0)
+        return 1000.0
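`get_gc_timeout()` couples the queue poll to garbage collection: while a GC is pending, the prompt loop wakes within `gc_collect_interval` seconds so `handle_flags` can collect, otherwise it sleeps for up to 1000 seconds. A small timing sketch (standalone, only the timing fields are exercised, so a stub server suffices):

    import time
    from comfy.worker_native import NativeWorker

    worker = NativeWorker(server_instance=None)   # stub server for this sketch
    assert worker.get_gc_timeout() == 1000.0      # idle: long poll
    worker.mark_needs_gc()
    worker.last_gc_collect = time.perf_counter()
    assert worker.get_gc_timeout() <= 10.0        # GC pending: wake within the interval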
diff --git a/comfy/worker_process.py b/comfy/worker_process.py
new file mode 100644
index 000000000000..de1b9cb9e616
--- /dev/null
+++ b/comfy/worker_process.py
@@ -0,0 +1,179 @@
+"""Subprocess worker for isolated prompt execution with complete GPU/ROCm reset."""
+
+import logging
+import multiprocessing as mp
+import time
+import traceback
+
+mp.set_start_method('spawn', force=True)
+
+
+def _deserialize_preview(msg):
+    """Deserialize preview image from IPC transport."""
+    if not (isinstance(msg['data'], dict) and msg['data'].get('_serialized')):
+        return msg
+
+    from PIL import Image
+    from io import BytesIO
+    import base64
+
+    s = msg['data']
+    pil_image = Image.open(BytesIO(base64.b64decode(s['image_bytes'])))
+    msg['data'] = ((s['image_type'], pil_image, s['max_size']), s['metadata'])
+    return msg
+
+
+def _error_result(worker_id, prompt_id, error, tb=None):
+    return {
+        'success': False,
+        'error': error,
+        'traceback': tb,
+        'history_result': {},
+        'status_messages': [],
+        'worker_id': worker_id,
+        'prompt_id': prompt_id
+    }
+
+
+def _kill_worker(worker, worker_id):
+    if not worker.is_alive():
+        return
+    worker.terminate()
+    worker.join(timeout=2)
+    if worker.is_alive():
+        logging.warning(f"Worker {worker_id} didn't terminate, killing")
+        worker.kill()
+        worker.join()
+
+
+class SubprocessWorker:
+    """Executes each prompt in an isolated subprocess with a fresh GPU context."""
+
+    def __init__(self, server_instance, timeout=600):
+        self.server_instance = server_instance
+        self.timeout = timeout
+        self.worker_counter = 0
+        self.current_worker = None
+        self.interrupt_event = None
+        logging.info("SubprocessWorker created - each job will run in an isolated process")
+
+    async def initialize(self):
+        """Load node definitions for prompt validation. Returns node count."""
+        from comfy.execution_core import init_execution_environment
+        return await init_execution_environment()
+
+    def handle_flags(self, flags):
+        pass
+
+    def mark_needs_gc(self):
+        pass
+
+    def get_gc_timeout(self):
+        return 1000.0
+
+    def interrupt(self, value=True):
+        if not value:
+            return
+        if self.interrupt_event:
+            self.interrupt_event.set()
+        if self.current_worker and self.current_worker.is_alive():
+            self.current_worker.join(timeout=2)
+            _kill_worker(self.current_worker, self.worker_counter)
+            self.current_worker = None
+
+    def _relay_messages(self, message_queue, server):
+        """Relay queued messages to the UI."""
+        while not message_queue.empty():
+            try:
+                msg = _deserialize_preview(message_queue.get_nowait())
+                if server:
+                    server.send_sync(msg['event'], msg['data'], msg['sid'])
+            except Exception:
+                break
+
+    def execute_prompt(self, prompt, prompt_id, extra_data={}, execute_outputs=[], server=None):
+        self.worker_counter += 1
+        worker_id = self.worker_counter
+
+        job_queue = mp.Queue()
+        result_queue = mp.Queue()
+        message_queue = mp.Queue()
+        self.interrupt_event = mp.Event()
+
+        client_id = extra_data.get('client_id')
+        client_metadata = {}
+        if client_id and hasattr(server, 'sockets_metadata'):
+            client_metadata = server.sockets_metadata.get(client_id, {})
+
+        job_data = {
+            'prompt': prompt,
+            'prompt_id': prompt_id,
+            'extra_data': extra_data,
+            'execute_outputs': execute_outputs,
+            'client_sockets_metadata': client_metadata
+        }
+
+        from comfy.worker_process_child import worker_main
+        worker = mp.Process(
+            target=worker_main,
+            args=(job_queue, result_queue, message_queue, self.interrupt_event, worker_id),
+            name=f'ComfyUI-Worker-{worker_id}'
+        )
+
+        logging.info(f"Starting worker {worker_id} for prompt {prompt_id}")
+        self.current_worker = worker
+        worker.start()
+        job_queue.put(job_data)
+
+        try:
+            start_time = time.time()
+            result = None
+
+            while result is None:
+                if self.interrupt_event.is_set():
+                    logging.info(f"Worker {worker_id} interrupted")
+                    if server:
+                        server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
+                    return _error_result(worker_id, prompt_id, 'Execution interrupted by user')
+
+                if time.time() - start_time > self.timeout:
+                    raise TimeoutError()
+
+                self._relay_messages(message_queue, server)
+
+                try:
+                    result = result_queue.get(timeout=0.1)
+                except mp.queues.Empty:
+                    pass
+
+            self._relay_messages(message_queue, server)
+
+            worker.join(timeout=5)
+            if worker.is_alive():
+                _kill_worker(worker, worker_id)
+
+            logging.info(f"Worker {worker_id} cleaned up (exit code: {worker.exitcode})")
+            self.current_worker = None
+            return result
+
+        except TimeoutError:
+            error = f"Worker {worker_id} timed out after {self.timeout}s. Use --subprocess-timeout to increase it."
+            logging.error(error)
+            _kill_worker(worker, worker_id)
+            self.current_worker = None
+            return _error_result(worker_id, prompt_id, error)
+
+        except Exception as e:
+            error = f"Worker {worker_id} IPC error: {e}"
+            logging.error(f"{error}\n{traceback.format_exc()}")
+            _kill_worker(worker, worker_id)
+            self.current_worker = None
+            return _error_result(worker_id, prompt_id, error, traceback.format_exc())
+
+        finally:
+            for q in (job_queue, result_queue, message_queue):
+                q.close()
+                try:
+                    q.join_thread()
+                except Exception:
+                    pass
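Preview images are the one payload that cannot cross the process boundary as-is, so the child base64-encodes the PIL image and the parent's `_deserialize_preview` rebuilds the `((image_type, image, max_size), metadata)` tuple. A round-trip sketch of that transport format (all field values hypothetical):

    import base64
    from io import BytesIO

    from PIL import Image

    from comfy.worker_process import _deserialize_preview

    img = Image.new("RGB", (8, 8))
    buf = BytesIO()
    img.save(buf, format="JPEG")
    msg = {"event": "preview", "sid": None, "data": {
        "_serialized": True,
        "image_type": "JPEG",
        "image_bytes": base64.b64encode(buf.getvalue()).decode("utf-8"),
        "max_size": 512,
        "metadata": {},
    }}
    (image_type, pil_image, max_size), metadata = _deserialize_preview(msg)["data"]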
diff --git a/comfy/worker_process_child.py b/comfy/worker_process_child.py
new file mode 100644
index 000000000000..e4990fd17182
--- /dev/null
+++ b/comfy/worker_process_child.py
@@ -0,0 +1,104 @@
+"""Subprocess worker child process entry point."""
+
+import logging
+import multiprocessing as mp
+import traceback
+
+
+class IPCMessageServer:
+    """IPC-based message server for subprocess workers."""
+
+    def __init__(self, message_queue, client_id=None, sockets_metadata=None):
+        self.message_queue = message_queue
+        self.client_id = client_id
+        self.last_node_id = None
+        self.last_prompt_id = None
+        self.sockets_metadata = sockets_metadata or {}
+
+    def send_sync(self, event, data, sid=None):
+        from protocol import BinaryEventTypes
+        from io import BytesIO
+        import base64
+
+        if event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA and isinstance(data, tuple):
+            preview_image, metadata = data
+            image_type, pil_image, max_size = preview_image
+
+            buffer = BytesIO()
+            pil_image.save(buffer, format=image_type)
+
+            data = {
+                '_serialized': True,
+                'image_type': image_type,
+                'image_bytes': base64.b64encode(buffer.getvalue()).decode('utf-8'),
+                'max_size': max_size,
+                'metadata': metadata
+            }
+
+        self.message_queue.put_nowait({'event': event, 'data': data, 'sid': sid})
+
+    def queue_updated(self):
+        pass
+
+
+def worker_main(job_queue, result_queue, message_queue, interrupt_event, worker_id):
+    """Subprocess worker entry point - spawned fresh for each execution."""
+    job_data = None
+    try:
+        logging.basicConfig(level=logging.INFO, format=f'[Worker-{worker_id}] %(levelname)s: %(message)s')
+        logging.info(f"Worker {worker_id} starting (PID: {mp.current_process().pid})")
+
+        import asyncio
+        import comfy.model_management
+        from comfy.worker_native import NativeWorker
+        from comfy.execution_core import WorkerServer
+
+        logging.info(f"Worker {worker_id} initialized. Device: {comfy.model_management.get_torch_device()}")
+
+        job_data = job_queue.get(timeout=30)
+        client_id = job_data.get('extra_data', {}).get('client_id')
+        client_metadata = job_data.get('client_sockets_metadata', {})
+
+        sockets_metadata = {client_id: client_metadata} if client_id and client_metadata else {}
+        ipc_server = IPCMessageServer(message_queue, client_id, sockets_metadata)
+        server = WorkerServer(ipc_server)
+
+        def check_interrupt():
+            if interrupt_event.is_set():
+                raise comfy.model_management.InterruptProcessingException()
+
+        worker = NativeWorker(server, interrupt_checker=check_interrupt)
+
+        import comfy.execution_core
+        comfy.execution_core._active_worker = worker
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        node_count = loop.run_until_complete(worker.initialize())
+        logging.info(f"Worker {worker_id} loaded {node_count} node types")
+
+        result = worker.execute_prompt(
+            job_data['prompt'],
+            job_data['prompt_id'],
+            job_data.get('extra_data', {}),
+            job_data.get('execute_outputs', [])
+        )
+        result['worker_id'] = worker_id
+
+        logging.info(f"Worker {worker_id} completed successfully")
+        result_queue.put(result)
+
+    except Exception as e:
+        logging.error(f"Worker {worker_id} failed: {e}\n{traceback.format_exc()}")
+        result_queue.put({
+            'success': False,
+            'error': str(e),
+            'traceback': traceback.format_exc(),
+            'history_result': {},
+            'status_messages': [],
+            'worker_id': worker_id,
+            'prompt_id': job_data.get('prompt_id', 'unknown') if job_data else 'unknown'
+        })
+
+    finally:
+        logging.info(f"Worker {worker_id} exiting")
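`worker_main` expects exactly one job on `job_queue` and answers once on `result_queue`, with previews and status events flowing through `message_queue` in the meantime; that is the handshake `SubprocessWorker.execute_prompt` drives. Stripped of timeouts, interrupts, and cleanup, the parent side reduces to roughly the following (a sketch only: the child still needs a full ComfyUI environment, and the empty prompt shown here would fail validation):

    import multiprocessing as mp

    from comfy.worker_process_child import worker_main

    if __name__ == "__main__":
        mp.set_start_method("spawn")
        job_q, result_q, msg_q = mp.Queue(), mp.Queue(), mp.Queue()
        interrupt = mp.Event()
        child = mp.Process(target=worker_main, args=(job_q, result_q, msg_q, interrupt, 1))
        child.start()
        job_q.put({"prompt": {}, "prompt_id": "demo", "extra_data": {},
                   "execute_outputs": [], "client_sockets_metadata": {}})
        print(result_q.get(timeout=600))  # result dict with 'success', 'worker_id', ...
        child.join()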
diff --git a/main.py b/main.py
index 0d02a087b8a0..006b5a3081fb 100644
--- a/main.py
+++ b/main.py
@@ -11,9 +11,6 @@ import utils.extra_config
 import logging
 import sys
 
-from comfy_execution.progress import get_progress_state
-from comfy_execution.utils import get_executing_context
-from comfy_api import feature_flags
 
 
 if __name__ == "__main__":
@@ -176,16 +173,22 @@ def execute_script(script_path):
 
 import comfy.utils
-import execution
 import server
-from protocol import BinaryEventTypes
-import nodes
-import comfy.model_management
 import comfyui_version
 import app.logger
-import hook_breaker_ac10a0
+
+# Import modules needed for server operation
+# GPU initialization happens lazily when GPU functions are called
+# In subprocess mode, main process won't call GPU functions - workers will
+if __name__ == "__main__":
+    import execution
+    import nodes
+    import comfy.model_management
+
 
 def cuda_malloc_warning():
+    if args.use_subprocess_workers:
+        return
     device = comfy.model_management.get_torch_device()
     device_name = comfy.model_management.get_torch_device_name(device)
     cuda_malloc_warning = False
@@ -197,84 +200,6 @@ def cuda_malloc_warning():
             logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
 
-def prompt_worker(q, server_instance):
-    current_time: float = 0.0
-    cache_type = execution.CacheType.CLASSIC
-    if args.cache_lru > 0:
-        cache_type = execution.CacheType.LRU
-    elif args.cache_ram > 0:
-        cache_type = execution.CacheType.RAM_PRESSURE
-    elif args.cache_none:
-        cache_type = execution.CacheType.NONE
-
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
-    last_gc_collect = 0
-    need_gc = False
-    gc_collect_interval = 10.0
-
-    while True:
-        timeout = 1000.0
-        if need_gc:
-            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
-
-        queue_item = q.get(timeout=timeout)
-        if queue_item is not None:
-            item, item_id = queue_item
-            execution_start_time = time.perf_counter()
-            prompt_id = item[1]
-            server_instance.last_prompt_id = prompt_id
-
-            sensitive = item[5]
-            extra_data = item[3].copy()
-            for k in sensitive:
-                extra_data[k] = sensitive[k]
-
-            e.execute(item[2], prompt_id, extra_data, item[4])
-            need_gc = True
-
-            remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
-            q.task_done(item_id,
-                        e.history_result,
-                        status=execution.PromptQueue.ExecutionStatus(
-                            status_str='success' if e.success else 'error',
-                            completed=e.success,
-                            messages=e.status_messages), process_item=remove_sensitive)
-            if server_instance.client_id is not None:
-                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
-
-            current_time = time.perf_counter()
-            execution_time = current_time - execution_start_time
-
-            # Log Time in a more readable way after 10 minutes
-            if execution_time > 600:
-                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
-                logging.info(f"Prompt executed in {execution_time}")
-            else:
-                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
-
-        flags = q.get_flags()
-        free_memory = flags.get("free_memory", False)
-
-        if flags.get("unload_models", free_memory):
-            comfy.model_management.unload_all_models()
-            need_gc = True
-            last_gc_collect = 0
-
-        if free_memory:
-            e.reset()
-            need_gc = True
-            last_gc_collect = 0
-
-        if need_gc:
-            current_time = time.perf_counter()
-            if (current_time - last_gc_collect) > gc_collect_interval:
-                gc.collect()
-                comfy.model_management.soft_empty_cache()
-                last_gc_collect = current_time
-                need_gc = False
-                hook_breaker_ac10a0.restore_functions()
-
-
 async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
     addresses = []
     for addr in address.split(","):
@@ -283,37 +208,6 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
         server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
     )
 
-def hijack_progress(server_instance):
-    def hook(value, total, preview_image, prompt_id=None, node_id=None):
-        executing_context = get_executing_context()
-        if prompt_id is None and executing_context is not None:
-            prompt_id = executing_context.prompt_id
-        if node_id is None and executing_context is not None:
-            node_id = executing_context.node_id
-        comfy.model_management.throw_exception_if_processing_interrupted()
-        if prompt_id is None:
-            prompt_id = server_instance.last_prompt_id
-        if node_id is None:
-            node_id = server_instance.last_node_id
-        progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
-        get_progress_state().update_progress(node_id, value, total, preview_image)
-
-        server_instance.send_sync("progress", progress, server_instance.client_id)
-        if preview_image is not None:
-            # Only send old method if client doesn't support preview metadata
-            if not feature_flags.supports_feature(
-                server_instance.sockets_metadata,
-                server_instance.client_id,
-                "supports_preview_metadata",
-            ):
-                server_instance.send_sync(
-                    BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
-                    preview_image,
-                    server_instance.client_id,
-                )
-
-    comfy.utils.set_progress_bar_global_hook(hook)
-
 
 def cleanup_temp():
     temp_dir = folder_paths.get_temp_directory()
@@ -356,20 +250,16 @@ def start_comfyui(asyncio_loop=None):
     if args.enable_manager and not args.disable_manager_ui:
         comfyui_manager.start()
 
-    hook_breaker_ac10a0.save_functions()
-    asyncio_loop.run_until_complete(nodes.init_extra_nodes(
-        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
-        init_api_nodes=not args.disable_api_nodes
-    ))
-    hook_breaker_ac10a0.restore_functions()
+    from comfy.execution_core import create_worker, prompt_worker
+    worker = create_worker(prompt_server)
+    node_count = asyncio_loop.run_until_complete(worker.initialize())
+    logging.info(f"Loaded {node_count} node types")
+    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, worker), name="PromptWorker").start()
 
     cuda_malloc_warning()
     setup_database()
 
     prompt_server.add_routes()
-    hijack_progress(prompt_server)
-
-    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
 
     if args.quick_test_for_ci:
         exit(0)
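Moving the GPU-touching imports under `if __name__ == "__main__":` is what keeps spawned children clean: under the `spawn` start method each child re-imports the parent's main module, so anything initialized at module scope would run again in every worker. A minimal stdlib-only illustration of that re-import behavior:

    import multiprocessing as mp

    print("module scope: runs in the parent AND in every spawned child")

    if __name__ == "__main__":
        # GPU/CUDA initialization belongs here, where children never re-execute it
        mp.set_start_method("spawn")
        p = mp.Process(target=print, args=("hello from the child",))
        p.start()
        p.join()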
diff --git a/nodes.py b/nodes.py
index 8d28a725d91e..2df22c05366c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -50,7 +50,8 @@ def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()
 
 def interrupt_processing(value=True):
-    comfy.model_management.interrupt_current_processing(value)
+    from comfy.execution_core import interrupt_processing as core_interrupt
+    core_interrupt(value)
 
 MAX_RESOLUTION=16384
 
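With the `nodes.py` change, a UI cancel is routed through whichever worker backend is active instead of poking `model_management` directly. The resulting call chain, traced in comments over the real entry point (assumes `create_worker` has already installed `_active_worker`):

    import nodes

    nodes.interrupt_processing(True)
    # -> comfy.execution_core.interrupt_processing(True)
    # -> _active_worker.interrupt(True)
    #      NativeWorker:     comfy.model_management.interrupt_current_processing(True)
    #      SubprocessWorker: interrupt_event.set(), then _kill_worker(current_worker, ...)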