11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3- import gc
4- from contextlib import contextmanager
3+ from unittest .mock import patch
54
65import numpy as np
76import torch
@@ -140,6 +139,7 @@ def capture_graph(
140139 attn_metadata ,
141140 self .vllm_config ,
142141 num_tokens = batch_size ,
142+ cudagraph_runtime_mode = CUDAGraphMode .NONE ,
143143 num_tokens_across_dp = num_tokens_across_dp ,
144144 ):
145145 hidden_states = model (
@@ -148,15 +148,16 @@ def capture_graph(
148148 )
149149 if self .hidden_states is None :
150150 self .hidden_states = torch .empty_like (hidden_states )
151- torch .cuda .synchronize ()
152151
153152 # Capture the graph.
154153 graph = torch .cuda .CUDAGraph ()
155154 with (
155+ patch ("torch.cuda.empty_cache" , lambda : None ),
156156 set_forward_context (
157157 attn_metadata ,
158158 self .vllm_config ,
159159 num_tokens = batch_size ,
160+ cudagraph_runtime_mode = CUDAGraphMode .NONE ,
160161 num_tokens_across_dp = num_tokens_across_dp ,
161162 ),
162163 torch .cuda .graph (graph , self .pool ),
@@ -183,7 +184,7 @@ def capture(
183184 if is_global_first_rank ():
184185 sizes_to_capture = tqdm (sizes_to_capture , desc = "Capturing CUDA graphs" )
185186
186- with freeze_gc (), graph_capture (device = self .device ):
187+ with graph_capture (device = self .device ):
187188 for batch_size in sizes_to_capture :
188189 self .capture_graph (
189190 batch_size ,
@@ -199,13 +200,3 @@ def run(self, batch_size: int) -> torch.Tensor:
199200 self .graphs [batch_size ].replay ()
200201 assert self .hidden_states is not None
201202 return self .hidden_states [:batch_size ]
202-
203-
204- @contextmanager
205- def freeze_gc ():
206- gc .collect ()
207- gc .freeze ()
208- try :
209- yield
210- finally :
211- gc .unfreeze ()
0 commit comments