
Commit a82562e

Add a level 3 sleep/wake_up that offloads tensors to disk
Co-authored-by: aavarghese <[email protected]>
Co-authored-by: manoelmarques <[email protected]>
Signed-off-by: Manoel Marques <[email protected]>
1 parent e246ad6 commit a82562e

File tree: 5 files changed, +159 −2 lines changed

tests/basic_correctness/test_cumem.py

Lines changed: 24 additions & 0 deletions
@@ -171,6 +171,30 @@ def test_end_to_end(model: str):
     # cmp output
     assert output[0].outputs[0].text == output3[0].outputs[0].text
 
+    # test sleep level 3 here.
+    llm.sleep(level=3)
+
+    free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
+    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
+    # now the memory usage is mostly cudagraph memory pool,
+    # and it should be less than the model weights (1B model, 2GiB weights)
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # therefore high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    if use_v1:
+        assert used_bytes < 7 * GiB_bytes
+    else:
+        assert used_bytes < 2 * GiB_bytes
+
+    llm.wake_up()
+    output2 = llm.generate(prompt, sampling_params)
+
+    # cmp output
+    assert output[0].outputs[0].text == output2[0].outputs[0].text
+
 
 @create_new_process_for_each_test()
 def test_deep_sleep():

tests/entrypoints/openai/test_sleep.py

Lines changed: 12 additions & 0 deletions
@@ -59,3 +59,15 @@ def test_sleep_mode():
     response = requests.get(remote_server.url_for("is_sleeping"))
     assert response.status_code == 200
     assert response.json().get("is_sleeping") is False
+
+    response = requests.post(remote_server.url_for("/sleep"), data={"level": "3"})
+    assert response.status_code == 200
+    response = requests.get(remote_server.url_for("/is_sleeping"))
+    assert response.status_code == 200
+    assert response.json().get("is_sleeping") is True
+
+    response = requests.post(remote_server.url_for("/wake_up"))
+    assert response.status_code == 200
+    response = requests.get(remote_server.url_for("/is_sleeping"))
+    assert response.status_code == 200
+    assert response.json().get("is_sleeping") is False
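Outside the test harness, the same endpoints can be exercised against a running server. The sketch below is a hedged adaptation of the test: the base URL is a placeholder, and (an assumption not stated in this commit) the sleep endpoints are development endpoints, so the server typically needs sleep mode enabled, e.g. started along the lines of `vllm serve <model> --enable-sleep-mode`, possibly with `VLLM_SERVER_DEV_MODE=1` depending on the vLLM version.

import requests

base = "http://localhost:8000"  # assumed server address

# Put the engine into level 3 sleep (weights spilled to disk).
assert requests.post(f"{base}/sleep", data={"level": "3"}).status_code == 200
assert requests.get(f"{base}/is_sleeping").json().get("is_sleeping") is True

# Wake it up again; weights are read back from the cache file.
assert requests.post(f"{base}/wake_up").status_code == 200
assert requests.get(f"{base}/is_sleeping").json().get("is_sleeping") is False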

vllm/device_allocator/cumem.py

Lines changed: 115 additions & 1 deletion
@@ -8,20 +8,68 @@
 # both of them failed because of cuda context mismatch.
 # not sure why, they are created from a different context.
 # the only successful approach is to call cuda driver API in C.
+import contextlib
+import ctypes
 import dataclasses
 import gc
+import io
+import mmap
 import os
+import struct
+import uuid
 from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
 
 
+def _copy_from_cuda_to_bytes(src_ptr: int, size_in_bytes: int) -> bytes:
+    dest_ptr = ctypes.create_string_buffer(size_in_bytes)
+    libcudart.cudaMemcpy(dest_ptr, src_ptr, size_in_bytes)
+    return bytes(dest_ptr)
+
+
+def _copy_from_bytes_to_cuda(dest_ptr: int, data: bytes) -> None:
+    # pass the size explicitly so no extra NUL terminator is appended
+    src_ptr = ctypes.create_string_buffer(data, len(data))
+    libcudart.cudaMemcpy(dest_ptr, src_ptr, len(data))
+
+
+def _write_bytes(data: bytes, binary_file: io.BufferedWriter) -> None:
+    # Pack the length as a 4-byte unsigned integer (little-endian)
+    data_len = len(data)
+    header = struct.pack("<I", data_len)
+    binary_file.write(header)
+    if data_len > 0:
+        binary_file.write(data)
+
+
+def _read_bytes(mmap_obj: mmap.mmap) -> bytes:
+    header = mmap_obj.read(4)
+    if not header:
+        raise ValueError("Missing header read")
+
+    if len(header) != 4:
+        raise ValueError("Incomplete header read")
+
+    data_len = struct.unpack("<I", header)[0]
+    if data_len == 0:
+        return b""
+
+    data = mmap_obj.read(data_len)
+
+    if len(data) != data_len:
+        raise ValueError("Incomplete data read")
+
+    return data
+
+
 def find_loaded_library(lib_name) -> Optional[str]:
     """
     According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
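For reference, the on-disk layout these helpers produce is a flat sequence of length-prefixed records: a 4-byte little-endian length header followed by the payload, with an empty payload encoded as a zero header and no body. A minimal standalone round-trip of that format (using a temporary file; no CUDA involved):

import mmap
import os
import struct
import tempfile

# Records mirror _write_bytes: 4-byte little-endian length, then payload.
# An empty record (e.g. a non-offloaded allocation) is just a zero header.
records = [b"weights-chunk", b"", b"another-chunk"]

with tempfile.NamedTemporaryFile(delete=False) as f:
    path = f.name
    for payload in records:
        f.write(struct.pack("<I", len(payload)))
        if payload:
            f.write(payload)

# Read the records back positionally, as wake_up does via mmap.
with open(path, "rb") as bin_file, \
        mmap.mmap(bin_file.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
    for expected in records:
        (data_len,) = struct.unpack("<I", mmap_obj.read(4))
        assert mmap_obj.read(data_len) == expected

os.remove(path)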
@@ -166,6 +214,8 @@ def __init__(self):
         self.python_malloc_callback = self._python_malloc_callback
         self.python_free_callback = self._python_free_callback
 
+        self.cache_filepath = ""
+
     def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
         Internal method to store the allocation data
@@ -197,7 +247,27 @@ def _python_free_callback(self, ptr: int) -> HandleType:
         )
         return data.handle
 
-    def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None:
+    def _delete_cache_file(self):
+        """
+        Remove the sleep cache file if it exists.
+        """
+        if self.cache_filepath != "":
+            filepath = self.cache_filepath
+            self.cache_filepath = ""
+            try:
+                with contextlib.suppress(FileNotFoundError):
+                    os.remove(filepath)
+                    logger.info("cache file %s deleted", filepath)
+            except Exception as e:
+                logger.warning(
+                    "failed to delete sleep cache file %s.", filepath, exc_info=e
+                )
+
+    def sleep(
+        self,
+        level: Optional[int] = 1,
+        offload_tags: Optional[Union[tuple[str, ...], str]] = None,
+    ) -> None:
         """
         Put the allocator in sleep mode.
         All data in the memory allocation with the specified tag will be

@@ -218,6 +288,31 @@ def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None:
         total_bytes = 0
         backup_bytes = 0
 
+        # remove the previous cache file if it exists
+        self._delete_cache_file()
+
+        # level 3: write weights to a file on disk
+        if level == 3:
+            unique_id = uuid.uuid4().hex
+            self.cache_filepath = os.path.join(
+                envs.VLLM_CACHE_ROOT, f"sleep_cache_{unique_id}.bin"
+            )
+            logger.info(
+                "sleep level %d writing to cache file %s", level, self.cache_filepath
+            )
+            with open(self.cache_filepath, "wb") as binary_file:
+                for ptr, data in self.pointer_to_data.items():
+                    handle = data.handle
+                    if data.tag in offload_tags:
+                        size_in_bytes = handle[1]
+                        data = _copy_from_cuda_to_bytes(ptr, size_in_bytes)
+                        _write_bytes(data, binary_file)
+                    else:
+                        _write_bytes(b"", binary_file)
+                    unmap_and_release(handle)
+            return
+
+        # handle other levels
         for ptr, data in self.pointer_to_data.items():
             handle = data.handle
             total_bytes += handle[1]
@@ -257,6 +352,25 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
         back to GPU memory. If None, all memory allocation will be loaded
         back to GPU memory.
         """
+        if self.cache_filepath != "":
+            logger.info("wake_up reading from cache file %s", self.cache_filepath)
+            with (
+                open(self.cache_filepath, "rb") as bin_file,
+                mmap.mmap(
+                    bin_file.fileno(), length=0, access=mmap.ACCESS_READ
+                ) as mmap_obj,
+            ):
+                for ptr, data in self.pointer_to_data.items():
+                    handle = data.handle
+                    create_and_map(handle)
+                    data = _read_bytes(mmap_obj)
+                    if len(data) > 0:
+                        _copy_from_bytes_to_cuda(ptr, data)
+
+            # remove the cache file
+            self._delete_cache_file()
+            return
+
         for ptr, data in self.pointer_to_data.items():
             if tags is None or data.tag in tags:
                 handle = data.handle
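Two details worth noting: the level 3 path writes one record per entry of pointer_to_data, emitting an empty record for allocations whose tag is not offloaded, and wake_up matches records back to pointers purely by position, so both sides rely on iterating the same dict in the same insertion order. At the call site, the worker drives the allocator roughly as follows — a minimal sketch adapted from the gpu_worker.py change below, assuming a CUDA-capable environment and omitting the surrounding bookkeeping:

import torch

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Level 1 offloads "weights"-tagged allocations to CPU; level 3 spills
# them to a cache file on disk instead. Other tags are simply discarded.
level = 3
allocator.sleep(level, offload_tags=("weights",) if level in (1, 3) else tuple())

free_bytes_after_sleep, total = torch.cuda.mem_get_info()
print(f"GPU bytes still in use: {total - free_bytes_after_sleep}")

# Restores the mappings and, for level 3, streams weights back from disk.
allocator.wake_up()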

vllm/entrypoints/llm.py

Lines changed: 5 additions & 0 deletions
@@ -1487,6 +1487,11 @@ def sleep(self, level: int = 1):
                 sleep is good for sleeping and waking up the engine to run a
                 different model or update the model, where previous model
                 weights are not needed. It reduces CPU memory pressure.
+                Level 3 sleep offloads the model weights to disk and
+                discards the kv cache. The weights are not backed up in
+                CPU memory, and the content of the kv cache is forgotten.
+                Level 3 sleep keeps CPU memory usage minimal and loads
+                the weights efficiently from disk when woken up.
         """
         self.reset_prefix_cache()
         self.llm_engine.sleep(level=level)
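In user code this is the sequence exercised in test_cumem.py above. A condensed sketch, where the model name and prompt are placeholders and enable_sleep_mode follows the tests:

from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)  # placeholder model
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)

baseline = llm.generate(prompt, sampling_params)

# Level 3: weights go to a file under VLLM_CACHE_ROOT; the kv cache is dropped.
llm.sleep(level=3)

# ... GPU memory is now free for other work ...

llm.wake_up()  # reloads the weights from the cache file
output = llm.generate(prompt, sampling_params)
assert baseline[0].outputs[0].text == output[0].outputs[0].text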

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 1 deletion
@@ -120,7 +120,9 @@ def sleep(self, level: int = 1) -> None:
         }
 
         allocator = CuMemAllocator.get_instance()
-        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+        allocator.sleep(
+            level, offload_tags=("weights",) if level == 1 or level == 3 else tuple()
+        )
         free_bytes_after_sleep, total = torch.cuda.mem_get_info()
         freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         used_bytes = total - free_bytes_after_sleep
