
Commit 7cecb6d

worker/executor with subprocess executor
1 parent 56fa7db commit 7cecb6d

7 files changed: 543 additions & 127 deletions

comfy/cli_args.py

Lines changed: 2 additions & 0 deletions

@@ -145,6 +145,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--use-subprocess-workers", action="store_true", help="Execute each prompt in an isolated subprocess with complete GPU/ROCm context reset. Ensures clean state between jobs but adds startup overhead.")
+parser.add_argument("--subprocess-timeout", type=int, default=600, help="Timeout in seconds for subprocess execution (default: 600, only used with --use-subprocess-workers).")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
 
 class PerformanceFeature(enum.Enum):
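
As a usage note, not part of the diff: with these flags, a launch that isolates each prompt in a worker subprocess with a 20-minute timeout would look like this (assuming ComfyUI's standard main.py entry point):

python main.py --use-subprocess-workers --subprocess-timeout 1200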

comfy/execution_core.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@ (new file)

"""Core execution logic shared between normal and subprocess execution modes."""

import logging
import time

_active_worker = None


def create_worker(server_instance):
    """Create worker backend. Returns NativeWorker or SubprocessWorker."""
    global _active_worker
    from comfy.cli_args import args

    server = WorkerServer(server_instance)

    if args.use_subprocess_workers:
        from comfy.worker_process import SubprocessWorker
        worker = SubprocessWorker(server, timeout=args.subprocess_timeout)
    else:
        from comfy.worker_native import NativeWorker
        worker = NativeWorker(server)

    _active_worker = worker
    return worker
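
For orientation, a sketch of how a caller might wire this factory up (prompt_server and q are stand-ins for the server instance and prompt queue; this is not the committed startup code):

# Sketch only -- hypothetical startup wiring, not part of this diff.
import asyncio
import threading

from comfy.execution_core import create_worker, prompt_worker

worker = create_worker(prompt_server)          # NativeWorker or SubprocessWorker, per CLI args
node_count = asyncio.run(worker.initialize())  # load nodes, build the executor
threading.Thread(target=prompt_worker, args=(q, worker), daemon=True).start()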
async def init_execution_environment():
    """Load nodes and custom nodes. Returns number of node types loaded."""
    import nodes
    from comfy.cli_args import args

    await nodes.init_extra_nodes(
        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
        init_api_nodes=not args.disable_api_nodes
    )
    return len(nodes.NODE_CLASS_MAPPINGS)


def setup_progress_hook(server_instance, interrupt_checker):
    """Set up global progress hook. interrupt_checker must raise on interrupt."""
    import comfy.utils
    from comfy_execution.progress import get_progress_state
    from comfy_execution.utils import get_executing_context

    def hook(value, total, preview_image, prompt_id=None, node_id=None):
        ctx = get_executing_context()
        if ctx:
            prompt_id = prompt_id or ctx.prompt_id
            node_id = node_id or ctx.node_id

        interrupt_checker()

        prompt_id = prompt_id or server_instance.last_prompt_id
        node_id = node_id or server_instance.last_node_id

        get_progress_state().update_progress(node_id, value, total, preview_image)
        server_instance.send_sync("progress", {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}, server_instance.client_id)

    comfy.utils.set_progress_bar_global_hook(hook)
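
setup_progress_hook leaves the interruption policy to the caller: interrupt_checker runs on every progress tick and must raise to abort execution. NativeWorker (below) defaults to comfy.model_management.throw_exception_if_processing_interrupted; a hypothetical event-based checker, for illustration only:

# Hypothetical interrupt checker -- illustration, not committed code.
import threading

interrupt_event = threading.Event()

def raise_if_interrupted():
    if interrupt_event.is_set():
        raise InterruptedError("prompt execution was interrupted")

# setup_progress_hook(server_instance, interrupt_checker=raise_if_interrupted)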
class WorkerServer:
    """Protocol boundary: client_id, last_node_id, last_prompt_id, sockets_metadata, send_sync(), queue_updated()"""

    _WRITABLE = {'client_id', 'last_node_id', 'last_prompt_id'}

    def __init__(self, server):
        object.__setattr__(self, '_server', server)

    def __setattr__(self, name, value):
        if name in self._WRITABLE:
            setattr(self._server, name, value)
        else:
            raise AttributeError(f"WorkerServer does not accept attribute '{name}'")

    @property
    def client_id(self):
        return self._server.client_id

    @property
    def last_node_id(self):
        return self._server.last_node_id

    @property
    def last_prompt_id(self):
        return self._server.last_prompt_id

    @property
    def sockets_metadata(self):
        return self._server.sockets_metadata

    def send_sync(self, event, data, sid=None):
        self._server.send_sync(event, data, sid or self.client_id)

    def queue_updated(self):
        self._server.queue_updated()
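
The __setattr__ whitelist makes the protocol boundary enforceable: only the three mutable fields are forwarded to the wrapped server, and any other write fails loudly. A quick illustration (real_server is a stand-in for the actual server instance):

ws = WorkerServer(real_server)
ws.last_prompt_id = "abc-123"  # forwarded to real_server.last_prompt_id
ws.send_sync("status", {})     # delegates to real_server.send_sync
ws.loop = None                 # AttributeError: WorkerServer does not accept attribute 'loop'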

def interrupt_processing(value=True):
    _active_worker.interrupt(value)


def _strip_sensitive(prompt):
    # Drop the queue item's element at index 5 (the extra data merged into
    # extra_data in prompt_worker below) before the item is stored in history.
    return prompt[:5] + prompt[6:]

def prompt_worker(q, worker):
    """Main prompt execution loop."""
    import execution

    server = worker.server_instance

    while True:
        queue_item = q.get(timeout=worker.get_gc_timeout())
        if queue_item is not None:
            item, item_id = queue_item
            start_time = time.perf_counter()
            prompt_id = item[1]
            server.last_prompt_id = prompt_id

            extra_data = {**item[3], **item[5]}

            result = worker.execute_prompt(item[2], prompt_id, extra_data, item[4], server=server)
            worker.mark_needs_gc()

            q.task_done(
                item_id,
                result['history_result'],
                status=execution.PromptQueue.ExecutionStatus(
                    status_str='success' if result['success'] else 'error',
                    completed=result['success'],
                    messages=result['status_messages']
                ),
                process_item=_strip_sensitive
            )

            if server.client_id is not None:
                server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)

            elapsed = time.perf_counter() - start_time
            if elapsed > 600:
                logging.info(f"Prompt executed in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
            else:
                logging.info(f"Prompt executed in {elapsed:.2f} seconds")

        worker.handle_flags(q.get_flags())
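
comfy/worker_process.py (the SubprocessWorker side of this commit) is not shown in this excerpt. From the call sites in create_worker, prompt_worker, and NativeWorker below, the contract both backends must satisfy can be reconstructed roughly as follows (an inferred sketch, not the committed code):

# Inferred worker contract -- reconstructed from call sites, not committed code.
from typing import Any, Protocol

class Worker(Protocol):
    server_instance: Any

    async def initialize(self) -> int: ...        # returns loaded node count
    def execute_prompt(self, prompt, prompt_id, extra_data,
                       execute_outputs, server=None) -> dict: ...
                                                  # keys: success, history_result,
                                                  # status_messages, prompt_id
    def handle_flags(self, flags: dict) -> None: ...
    def interrupt(self, value: bool = True) -> None: ...
    def mark_needs_gc(self) -> None: ...
    def get_gc_timeout(self) -> float: ...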

comfy/worker_native.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@ (new file)

"""Native (in-process) worker for prompt execution."""

import time
import gc


class NativeWorker:
    """Executes prompts in the same process as the server."""

    def __init__(self, server_instance, interrupt_checker=None):
        self.server_instance = server_instance
        self.interrupt_checker = interrupt_checker
        self.executor = None
        self.last_gc_collect = 0
        self.need_gc = False
        self.gc_collect_interval = 10.0

    async def initialize(self):
        """Load nodes and set up executor. Returns node count."""
        from execution import PromptExecutor, CacheType
        from comfy.cli_args import args
        from comfy.execution_core import init_execution_environment, setup_progress_hook
        import comfy.model_management as mm
        import hook_breaker_ac10a0

        hook_breaker_ac10a0.save_functions()
        try:
            node_count = await init_execution_environment()
        finally:
            hook_breaker_ac10a0.restore_functions()

        interrupt_checker = self.interrupt_checker or mm.throw_exception_if_processing_interrupted
        setup_progress_hook(self.server_instance, interrupt_checker=interrupt_checker)

        cache_type = CacheType.CLASSIC
        if args.cache_lru > 0:
            cache_type = CacheType.LRU
        elif args.cache_ram > 0:
            cache_type = CacheType.RAM_PRESSURE
        elif args.cache_none:
            cache_type = CacheType.NONE

        self.executor = PromptExecutor(
            self.server_instance,
            cache_type=cache_type,
            cache_args={"lru": args.cache_lru, "ram": args.cache_ram}
        )
        return node_count

    def execute_prompt(self, prompt, prompt_id, extra_data, execute_outputs, server=None):
        self.executor.execute(prompt, prompt_id, extra_data, execute_outputs)
        return {
            'success': self.executor.success,
            'history_result': self.executor.history_result,
            'status_messages': self.executor.status_messages,
            'prompt_id': prompt_id
        }

    def handle_flags(self, flags):
        import comfy.model_management as mm
        import hook_breaker_ac10a0

        free_memory = flags.get("free_memory", False)

        if flags.get("unload_models", free_memory):
            mm.unload_all_models()
            self.need_gc = True
            self.last_gc_collect = 0

        if free_memory:
            if self.executor:
                self.executor.reset()
            self.need_gc = True
            self.last_gc_collect = 0

        if self.need_gc:
            current_time = time.perf_counter()
            if (current_time - self.last_gc_collect) > self.gc_collect_interval:
                gc.collect()
                mm.soft_empty_cache()
                self.last_gc_collect = current_time
                self.need_gc = False
                hook_breaker_ac10a0.restore_functions()

    def interrupt(self, value=True):
        import comfy.model_management
        comfy.model_management.interrupt_current_processing(value)

    def mark_needs_gc(self):
        self.need_gc = True

    def get_gc_timeout(self):
        # While GC is pending, wake the queue loop as soon as the debounce
        # interval expires; otherwise block for up to 1000 s waiting for work.
        if self.need_gc:
            return max(self.gc_collect_interval - (time.perf_counter() - self.last_gc_collect), 0.0)
        return 1000.0
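
Taken together with prompt_worker above, these three methods implement a debounced idle-time GC: finishing a prompt arms the timer via mark_needs_gc(), the queue wait shortens to the remaining debounce interval, and a timed-out queue read gives handle_flags() a chance to collect. One pass, illustrated (q and worker are stand-ins; a hypothetical loop excerpt, not committed code):

# Idle-GC handshake, illustrated.
worker.mark_needs_gc()               # arm: a prompt just finished
timeout = worker.get_gc_timeout()    # <= 10.0 s while GC is pending, else 1000.0
queue_item = q.get(timeout=timeout)  # returns None when the wait times out
worker.handle_flags(q.get_flags())   # on an idle pass, runs gc.collect() /
                                     # soft_empty_cache() once 10 s have elapsed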
