88# both of them failed because of cuda context mismatch.
99# not sure why, they are created from a different context.
1010# the only successful approach is to call cuda driver API in C.
11+ import contextlib
12+ import ctypes
1113import dataclasses
1214import gc
15+ import io
16+ import mmap
1317import os
18+ import struct
19+ import uuid
1420from collections .abc import Callable
1521from contextlib import contextmanager
1622from typing import Any
1723
1824import torch
1925
26+ import vllm .envs as envs
2027from vllm .logger import init_logger
2128from vllm .utils .platform_utils import is_pin_memory_available
2229
2330logger = init_logger (__name__ )
2431
2532
def _copy_from_cuda_to_bytes(scr_ptr: int, size_in_bytes: int) -> bytes:
    """Read ``size_in_bytes`` bytes at address ``scr_ptr`` into a ``bytes`` object.

    A host-side ctypes buffer of the requested size is allocated and filled
    through ``libcudart.cudaMemcpy``; its contents are then returned as an
    immutable ``bytes`` copy.

    Args:
        scr_ptr: Source address handed to ``cudaMemcpy`` as an integer.
        size_in_bytes: Number of bytes to copy.

    Returns:
        The copied bytes; its length equals ``size_in_bytes``.
    """
    host_buffer = ctypes.create_string_buffer(size_in_bytes)
    libcudart.cudaMemcpy(host_buffer, scr_ptr, size_in_bytes)
    # ``.raw`` exposes the whole buffer, including any embedded NUL bytes.
    return host_buffer.raw
37+
38+
def _copy_from_bytes_to_cuda(dest_ptr: int, data: bytes) -> None:
    """Copy the contents of ``data`` to address ``dest_ptr``.

    The payload is first staged in a host-side ctypes buffer, which is then
    handed to ``libcudart.cudaMemcpy`` together with the destination address.

    Args:
        dest_ptr: Destination address handed to ``cudaMemcpy`` as an integer.
        data: Bytes to copy; exactly ``len(data)`` bytes are transferred.
    """
    byte_count = len(data)
    # Giving an explicit size equal to the payload length makes the staging
    # buffer exactly ``byte_count`` bytes long (no extra NUL terminator).
    staging_buffer = ctypes.create_string_buffer(data, byte_count)
    libcudart.cudaMemcpy(dest_ptr, staging_buffer, byte_count)
43+
44+
# Length prefix stored in front of every record: little-endian unsigned
# 64-bit. A 32-bit prefix would make struct.pack raise struct.error for any
# single payload of 4 GiB or more.
_RECORD_HEADER = struct.Struct("<Q")


def _write_bytes(data: bytes, binary_file: io.BufferedWriter) -> None:
    """Append one length-prefixed record to ``binary_file``.

    The record is an 8-byte little-endian unsigned length followed by the
    payload. An empty payload is written as a header only, so the reader
    still sees one record per call.

    Args:
        data: Payload bytes; may be empty.
        binary_file: Writable binary stream positioned where the record
            should go.
    """
    binary_file.write(_RECORD_HEADER.pack(len(data)))
    if data:
        binary_file.write(data)


def _read_bytes(mmap_obj: mmap.mmap) -> bytes:
    """Read the next record written by ``_write_bytes`` from ``mmap_obj``.

    Args:
        mmap_obj: Readable object positioned at the start of a record; only
            its ``read(n)`` method is used.

    Returns:
        The record payload, or ``b""`` for a zero-length record.

    Raises:
        ValueError: If the header is missing or truncated, or if fewer
            payload bytes are available than the header announces.
    """
    header = mmap_obj.read(_RECORD_HEADER.size)
    if not header:
        raise ValueError("Missing header read")
    if len(header) != _RECORD_HEADER.size:
        raise ValueError("Incomplete header read")

    (data_len,) = _RECORD_HEADER.unpack(header)
    if data_len == 0:
        return b""

    data = mmap_obj.read(data_len)
    if len(data) != data_len:
        raise ValueError("Incomplete data read")
    return data
72+
73+
2674def find_loaded_library (lib_name ) -> str | None :
2775 """
2876 According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
@@ -167,6 +215,8 @@ def __init__(self):
167215 self .python_malloc_callback = self ._python_malloc_callback
168216 self .python_free_callback = self ._python_free_callback
169217
218+ self .cache_filepath = ""
219+
170220 def _python_malloc_callback (self , allocation_handle : HandleType ) -> None :
171221 """
172222 Internal method to store the allocation data
@@ -198,7 +248,25 @@ def _python_free_callback(self, ptr: int) -> HandleType:
198248 )
199249 return data .handle
200250
201- def sleep (self , offload_tags : tuple [str , ...] | str | None = None ) -> None :
def _delete_cache_file(self):
    """
    Delete the sleep cache file recorded in ``self.cache_filepath``, if any.

    ``self.cache_filepath`` is reset to "" before the removal is attempted,
    so the path is forgotten even when deletion fails. A file that is
    already gone is ignored silently; any other error is logged as a
    warning instead of being raised.
    """
    if self.cache_filepath == "":
        # No cache file has been recorded; nothing to clean up.
        return

    stale_path = self.cache_filepath
    self.cache_filepath = ""
    try:
        os.remove(stale_path)
        logger.info("cache file %s deleted", stale_path)
    except FileNotFoundError:
        # The file is already gone, which is the desired end state.
        pass
    except Exception as e:
        logger.warning(
            "failed to delete sleep cache file %s.", stale_path, exc_info=e
        )
266+
267+ def sleep (
268+ self , level : int = 1 , offload_tags : tuple [str , ...] | str | None = None
269+ ) -> None :
202270 """
203271 Put the allocator in sleep mode.
204272 All data in the memory allocation with the specified tag will be
@@ -219,6 +287,31 @@ def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
219287 total_bytes = 0
220288 backup_bytes = 0
221289
290+ # remove previous file if exists
291+ self ._delete_cache_file ()
292+
293+ # level 3 write weights to file
294+ if level == 3 :
295+ unique_id = uuid .uuid4 ().hex
296+ self .cache_filepath = os .path .join (
297+ envs .VLLM_CACHE_ROOT , f"sleep_cache_{ unique_id } .bin"
298+ )
299+ logger .info (
300+ "sleep level %d writing to cache file %s" , level , self .cache_filepath
301+ )
302+ with open (self .cache_filepath , "wb" ) as binary_file :
303+ for ptr , data in self .pointer_to_data .items ():
304+ handle = data .handle
305+ if data .tag in offload_tags :
306+ size_in_bytes = handle [1 ]
307+ data = _copy_from_cuda_to_bytes (ptr , size_in_bytes )
308+ _write_bytes (data , binary_file )
309+ else :
310+ _write_bytes (b"" , binary_file )
311+ unmap_and_release (handle )
312+ return
313+
314+ # handle other levels
222315 for ptr , data in self .pointer_to_data .items ():
223316 handle = data .handle
224317 total_bytes += handle [1 ]
@@ -258,6 +351,25 @@ def wake_up(self, tags: list[str] | None = None) -> None:
258351 back to GPU memory. If None, all memory allocation will be loaded
259352 back to GPU memory.
260353 """
354+ if self .cache_filepath != "" :
355+ logger .info ("wake_up reading from cache file %s" , self .cache_filepath )
356+ with (
357+ open (self .cache_filepath , "rb" ) as bin_file ,
358+ mmap .mmap (
359+ bin_file .fileno (), length = 0 , access = mmap .ACCESS_READ
360+ ) as mmap_obj ,
361+ ):
362+ for ptr , data in self .pointer_to_data .items ():
363+ handle = data .handle
364+ create_and_map (handle )
365+ data = _read_bytes (mmap_obj )
366+ if len (data ) > 0 :
367+ _copy_from_bytes_to_cuda (ptr , data )
368+
369+ # remove file
370+ self ._delete_cache_file ()
371+ return
372+
261373 for ptr , data in self .pointer_to_data .items ():
262374 if tags is None or data .tag in tags :
263375 handle = data .handle
0 commit comments