Skip to content

Commit 20ee418

Browse files
authored
[Model Runner V2] Minor fix for cudagraph_utils (#29256)
1 parent 389aa1b commit 20ee418

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import gc
4-
from contextlib import contextmanager
3+
from unittest.mock import patch
54

65
import numpy as np
76
import torch
@@ -140,6 +139,7 @@ def capture_graph(
140139
attn_metadata,
141140
self.vllm_config,
142141
num_tokens=batch_size,
142+
cudagraph_runtime_mode=CUDAGraphMode.NONE,
143143
num_tokens_across_dp=num_tokens_across_dp,
144144
):
145145
hidden_states = model(
@@ -148,15 +148,16 @@ def capture_graph(
148148
)
149149
if self.hidden_states is None:
150150
self.hidden_states = torch.empty_like(hidden_states)
151-
torch.cuda.synchronize()
152151

153152
# Capture the graph.
154153
graph = torch.cuda.CUDAGraph()
155154
with (
155+
patch("torch.cuda.empty_cache", lambda: None),
156156
set_forward_context(
157157
attn_metadata,
158158
self.vllm_config,
159159
num_tokens=batch_size,
160+
cudagraph_runtime_mode=CUDAGraphMode.NONE,
160161
num_tokens_across_dp=num_tokens_across_dp,
161162
),
162163
torch.cuda.graph(graph, self.pool),
@@ -183,7 +184,7 @@ def capture(
183184
if is_global_first_rank():
184185
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
185186

186-
with freeze_gc(), graph_capture(device=self.device):
187+
with graph_capture(device=self.device):
187188
for batch_size in sizes_to_capture:
188189
self.capture_graph(
189190
batch_size,
@@ -199,13 +200,3 @@ def run(self, batch_size: int) -> torch.Tensor:
199200
self.graphs[batch_size].replay()
200201
assert self.hidden_states is not None
201202
return self.hidden_states[:batch_size]
202-
203-
204-
@contextmanager
205-
def freeze_gc():
206-
gc.collect()
207-
gc.freeze()
208-
try:
209-
yield
210-
finally:
211-
gc.unfreeze()

vllm/v1/worker/gpu/model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,7 @@ def capture_model(self) -> int:
298298
return 0
299299

300300
start_time = time.perf_counter()
301+
torch.cuda.empty_cache()
301302
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
302303

303304
with self.maybe_setup_dummy_loras(self.lora_config):

0 commit comments

Comments
 (0)