88# both of them failed because of cuda context mismatch.
99# not sure why, they are created from a different context.
1010# the only successful approach is to call cuda driver API in C.
11+ import contextlib
12+ import ctypes
1113import dataclasses
1214import gc
15+ import io
16+ import mmap
1317import os
18+ import struct
19+ import uuid
1420from contextlib import contextmanager
1521from typing import Any , Callable , Optional , Union
1622
1723import torch
1824
25+ import vllm .envs as envs
1926from vllm .logger import init_logger
2027from vllm .utils import is_pin_memory_available
2128
2229logger = init_logger (__name__ )
2330
2431
32+ def _copy_from_cuda_to_bytes (scr_ptr : int , size_in_bytes : int ) -> bytes :
33+ dest_ptr = ctypes .create_string_buffer (size_in_bytes )
34+ libcudart .cudaMemcpy (dest_ptr , scr_ptr , size_in_bytes )
35+ return bytes (dest_ptr )
36+
37+
38+ def _copy_from_bytes_to_cuda (dest_ptr : int , data : bytes ) -> None :
39+ # reserve space for 0 termination
40+ scr_ptr = ctypes .create_string_buffer (data , len (data ))
41+ libcudart .cudaMemcpy (dest_ptr , scr_ptr , len (data ))
42+
43+
44+ def _write_bytes (data : bytes , binary_file : io .BufferedWriter ) -> None :
45+ # Pack the length as a 4-byte unsigned integer (little-endian)
46+ data_len = len (data )
47+ header = struct .pack ("<I" , data_len )
48+ binary_file .write (header )
49+ if data_len > 0 :
50+ binary_file .write (data )
51+
52+
53+ def _read_bytes (mmap_obj : mmap .mmap ) -> bytes :
54+ header = mmap_obj .read (4 )
55+ if not header :
56+ raise ValueError ("Missing header read" )
57+
58+ if len (header ) != 4 :
59+ raise ValueError ("Incomplete header read" )
60+
61+ data_len = struct .unpack ("<I" , header )[0 ]
62+ if data_len == 0 :
63+ return b""
64+
65+ data = mmap_obj .read (data_len )
66+
67+ if len (data ) != data_len :
68+ raise ValueError ("Incomplete data read" )
69+
70+ return data
71+
72+
2573def find_loaded_library (lib_name ) -> Optional [str ]:
2674 """
2775 According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
@@ -166,6 +214,8 @@ def __init__(self):
166214 self .python_malloc_callback = self ._python_malloc_callback
167215 self .python_free_callback = self ._python_free_callback
168216
217+ self .cache_filepath = ""
218+
169219 def _python_malloc_callback (self , allocation_handle : HandleType ) -> None :
170220 """
171221 Internal method to store the allocation data
@@ -197,7 +247,27 @@ def _python_free_callback(self, ptr: int) -> HandleType:
197247 )
198248 return data .handle
199249
200- def sleep (self , offload_tags : Optional [Union [tuple [str , ...], str ]] = None ) -> None :
250+ def _delete_cache_file (self ):
251+ """
252+ Remove sleep cache file if it exists
253+ """
254+ if self .cache_filepath != "" :
255+ filepath = self .cache_filepath
256+ self .cache_filepath = ""
257+ try :
258+ with contextlib .suppress (FileNotFoundError ):
259+ os .remove (filepath )
260+ logger .info ("cache file %s deleted" , filepath )
261+ except Exception as e :
262+ logger .warning (
263+ "failed to delete sleep cache file %s." , filepath , exc_info = e
264+ )
265+
266+ def sleep (
267+ self ,
268+ level : Optional [int ] = 1 ,
269+ offload_tags : Optional [Union [tuple [str , ...], str ]] = None ,
270+ ) -> None :
201271 """
202272 Put the allocator in sleep mode.
203273 All data in the memory allocation with the specified tag will be
@@ -218,6 +288,31 @@ def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> N
218288 total_bytes = 0
219289 backup_bytes = 0
220290
291+ # remove previous file if exists
292+ self ._delete_cache_file ()
293+
294+ # level 3 write weights to file
295+ if level == 3 :
296+ unique_id = uuid .uuid4 ().hex
297+ self .cache_filepath = os .path .join (
298+ envs .VLLM_CACHE_ROOT , f"sleep_cache_{ unique_id } .bin"
299+ )
300+ logger .info (
301+ "sleep level %d writing to cache file %s" , level , self .cache_filepath
302+ )
303+ with open (self .cache_filepath , "wb" ) as binary_file :
304+ for ptr , data in self .pointer_to_data .items ():
305+ handle = data .handle
306+ if data .tag in offload_tags :
307+ size_in_bytes = handle [1 ]
308+ data = _copy_from_cuda_to_bytes (ptr , size_in_bytes )
309+ _write_bytes (data , binary_file )
310+ else :
311+ _write_bytes (b"" , binary_file )
312+ unmap_and_release (handle )
313+ return
314+
315+ # handle other levels
221316 for ptr , data in self .pointer_to_data .items ():
222317 handle = data .handle
223318 total_bytes += handle [1 ]
@@ -257,6 +352,25 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
257352 back to GPU memory. If None, all memory allocation will be loaded
258353 back to GPU memory.
259354 """
355+ if self .cache_filepath != "" :
356+ logger .info ("wake_up reading from cache file %s" , self .cache_filepath )
357+ with (
358+ open (self .cache_filepath , "rb" ) as bin_file ,
359+ mmap .mmap (
360+ bin_file .fileno (), length = 0 , access = mmap .ACCESS_READ
361+ ) as mmap_obj ,
362+ ):
363+ for ptr , data in self .pointer_to_data .items ():
364+ handle = data .handle
365+ create_and_map (handle )
366+ data = _read_bytes (mmap_obj )
367+ if len (data ) > 0 :
368+ _copy_from_bytes_to_cuda (ptr , data )
369+
370+ # remove file
371+ self ._delete_cache_file ()
372+ return
373+
260374 for ptr , data in self .pointer_to_data .items ():
261375 if tags is None or data .tag in tags :
262376 handle = data .handle
0 commit comments