88# both of them failed because of cuda context mismatch.
99# not sure why, they are created from a different context.
1010# the only successful approach is to call cuda driver API in C.
11+ import contextlib
12+ import ctypes
1113import dataclasses
1214import gc
15+ import io
16+ import mmap
1317import os
18+ import struct
19+ import uuid
1420from contextlib import contextmanager
1521from typing import Any , Callable , Optional , Union
1622
1723import torch
1824
25+ import vllm .envs as envs
1926from vllm .logger import init_logger
2027from vllm .utils import is_pin_memory_available
2128
2229logger = init_logger (__name__ )
2330
2431
32+ def _copy_from_cuda_to_bytes (scr_ptr : int , size_in_bytes : int ) -> bytes :
33+ dest_ptr = ctypes .create_string_buffer (size_in_bytes )
34+ libcudart .cudaMemcpy (dest_ptr , scr_ptr , size_in_bytes )
35+ return bytes (dest_ptr )
36+
37+
38+ def _copy_from_bytes_to_cuda (dest_ptr : int , data : bytes ) -> None :
39+ # reserve space for 0 termination
40+ scr_ptr = ctypes .create_string_buffer (data , len (data ))
41+ libcudart .cudaMemcpy (dest_ptr , scr_ptr , len (data ))
42+
43+
44+ def _write_bytes (data : bytes , binary_file : io .BufferedWriter ) -> None :
45+ # Pack the length as a 4-byte unsigned integer (little-endian)
46+ data_len = len (data )
47+ header = struct .pack ("<I" , data_len )
48+ binary_file .write (header )
49+ if data_len > 0 :
50+ binary_file .write (data )
51+
52+
53+ def _read_bytes (mmap_obj : mmap .mmap ) -> bytes :
54+ header = mmap_obj .read (4 )
55+ if not header :
56+ raise ValueError ("Missing header read" )
57+
58+ if len (header ) != 4 :
59+ raise ValueError ("Incomplete header read" )
60+
61+ data_len = struct .unpack ("<I" , header )[0 ]
62+ if data_len == 0 :
63+ return b''
64+
65+ data = mmap_obj .read (data_len )
66+
67+ if len (data ) != data_len :
68+ raise ValueError ("Incomplete data read" )
69+
70+ return data
71+
72+
2573def find_loaded_library (lib_name ) -> Optional [str ]:
2674 """
2775 According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
@@ -161,6 +209,8 @@ def __init__(self):
161209 self .python_malloc_callback = self ._python_malloc_callback
162210 self .python_free_callback = self ._python_free_callback
163211
212+ self .cache_filepath = ""
213+
164214 def _python_malloc_callback (self , allocation_handle : HandleType ) -> None :
165215 """
166216 Internal method to store the allocation data
@@ -185,8 +235,25 @@ def _python_free_callback(self, ptr: int) -> HandleType:
185235 data .handle [1 ], data .tag , ptr )
186236 return data .handle
187237
238+ def _delete_cache_file (self ):
239+ """
240+ Remove sleep cache file if it exists
241+ """
242+ if self .cache_filepath != "" :
243+ filepath = self .cache_filepath
244+ self .cache_filepath = ""
245+ try :
246+ with contextlib .suppress (FileNotFoundError ):
247+ os .remove (filepath )
248+ logger .info ("cache file %s deleted" , filepath )
249+ except Exception as e :
250+ logger .warning ("failed to delete sleep cache file %s." ,
251+ filepath ,
252+ exc_info = e )
253+
188254 def sleep (
189255 self ,
256+ level : Optional [int ] = 1 ,
190257 offload_tags : Optional [Union [tuple [str , ...],
191258 str ]] = None ) -> None :
192259 """
@@ -209,6 +276,29 @@ def sleep(
209276 total_bytes = 0
210277 backup_bytes = 0
211278
279+ # remove previous file if exists
280+ self ._delete_cache_file ()
281+
282+ # level 3 write weights to file
283+ if level == 3 :
284+ unique_id = uuid .uuid4 ().hex
285+ self .cache_filepath = os .path .join (envs .VLLM_CACHE_ROOT ,
286+ f"sleep_cache_{ unique_id } .bin" )
287+ logger .info ("sleep level %d writing to cache file %s" , level ,
288+ self .cache_filepath )
289+ with open (self .cache_filepath , 'wb' ) as binary_file :
290+ for ptr , data in self .pointer_to_data .items ():
291+ handle = data .handle
292+ if data .tag in offload_tags :
293+ size_in_bytes = handle [1 ]
294+ data = _copy_from_cuda_to_bytes (ptr , size_in_bytes )
295+ _write_bytes (data , binary_file )
296+ else :
297+ _write_bytes (b'' , binary_file )
298+ unmap_and_release (handle )
299+ return
300+
301+ # handle other levels
212302 for ptr , data in self .pointer_to_data .items ():
213303 handle = data .handle
214304 total_bytes += handle [1 ]
@@ -244,6 +334,23 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
244334 back to GPU memory. If None, all memory allocation will be loaded
245335 back to GPU memory.
246336 """
337+ if self .cache_filepath != "" :
338+ logger .info ("wake_up reading from cache file %s" ,
339+ self .cache_filepath )
340+ with open (self .cache_filepath , 'rb' ) as bin_file , \
341+ mmap .mmap (bin_file .fileno (),
342+ length = 0 , access = mmap .ACCESS_READ ) as mmap_obj :
343+ for ptr , data in self .pointer_to_data .items ():
344+ handle = data .handle
345+ create_and_map (handle )
346+ data = _read_bytes (mmap_obj )
347+ if len (data ) > 0 :
348+ _copy_from_bytes_to_cuda (ptr , data )
349+
350+ # remove file
351+ self ._delete_cache_file ()
352+ return
353+
247354 for ptr , data in self .pointer_to_data .items ():
248355 if tags is None or data .tag in tags :
249356 handle = data .handle
0 commit comments