
Commit b9e95d1

Add a level 3 sleep/wake_up that offloads tensors to disk
Co-authored-by: aavarghese <[email protected]>
Co-authored-by: manoelmarques <[email protected]>
Signed-off-by: Manoel Marques <[email protected]>
1 parent 1bb17ec commit b9e95d1
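
For orientation, a minimal usage sketch of the new level-3 mode, mirroring the end-to-end test below; the model name and prompt are placeholder assumptions, and a CUDA device with enable_sleep_mode=True is required:

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
params = SamplingParams(temperature=0, max_tokens=16)
before = llm.generate(["The capital of France is"], params)

llm.sleep(level=3)  # weights go to a disk cache file; the kv cache is discarded
llm.wake_up()       # weights are read back from disk; the cache file is deleted

after = llm.generate(["The capital of France is"], params)
assert before[0].outputs[0].text == after[0].outputs[0].text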

File tree

5 files changed: +154 -2 lines changed

- tests/basic_correctness/test_cumem.py
- tests/entrypoints/openai/test_sleep.py
- vllm/device_allocator/cumem.py
- vllm/entrypoints/llm.py
- vllm/v1/worker/gpu_worker.py

tests/basic_correctness/test_cumem.py

Lines changed: 21 additions & 0 deletions
@@ -174,6 +174,27 @@ def test_end_to_end(model: str):
     # cmp output
     assert output[0].outputs[0].text == output3[0].outputs[0].text
 
+    # test sleep level 3 here.
+    llm.sleep(level=3)
+
+    free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
+    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
+    # now the memory usage is mostly cudagraph memory pool,
+    # and it should be less than the model weights (1B model, 2GiB weights)
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # therefore high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    assert used_bytes < 7 * GiB_bytes
+
+    llm.wake_up()
+    output2 = llm.generate(prompt, sampling_params)
+
+    # cmp output
+    assert output[0].outputs[0].text == output2[0].outputs[0].text
+
 
 @create_new_process_for_each_test()
 def test_deep_sleep():

tests/entrypoints/openai/test_sleep.py

Lines changed: 12 additions & 0 deletions
@@ -85,6 +85,18 @@ def test_sleep_mode():
         assert weights_offloaded == 0
         assert discard_all == 0
 
+        response = requests.post(remote_server.url_for("sleep"), params={"level": "3"})
+        assert response.status_code == 200
+        response = requests.get(remote_server.url_for("is_sleeping"))
+        assert response.status_code == 200
+        assert response.json().get("is_sleeping") is True
+
+        response = requests.post(remote_server.url_for("wake_up"))
+        assert response.status_code == 200
+        response = requests.get(remote_server.url_for("is_sleeping"))
+        assert response.status_code == 200
+        assert response.json().get("is_sleeping") is False
+
 
 def _get_sleep_metrics_from_api(response: requests.Response):
     """Return (awake, weights_offloaded, discard_all)"""

vllm/device_allocator/cumem.py

Lines changed: 113 additions & 1 deletion
@@ -8,21 +8,69 @@
 # both of them failed because of cuda context mismatch.
 # not sure why, they are created from a different context.
 # the only successful approach is to call cuda driver API in C.
+import contextlib
+import ctypes
 import dataclasses
 import gc
+import io
+import mmap
 import os
+import struct
+import uuid
 from collections.abc import Callable
 from contextlib import contextmanager
 from typing import Any
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils.platform_utils import is_pin_memory_available
 
 logger = init_logger(__name__)
 
 
+def _copy_from_cuda_to_bytes(scr_ptr: int, size_in_bytes: int) -> bytes:
+    dest_ptr = ctypes.create_string_buffer(size_in_bytes)
+    libcudart.cudaMemcpy(dest_ptr, scr_ptr, size_in_bytes)
+    return bytes(dest_ptr)
+
+
+def _copy_from_bytes_to_cuda(dest_ptr: int, data: bytes) -> None:
+    # exact size; no extra byte for NUL termination is reserved
+    scr_ptr = ctypes.create_string_buffer(data, len(data))
+    libcudart.cudaMemcpy(dest_ptr, scr_ptr, len(data))
+
+
+def _write_bytes(data: bytes, binary_file: io.BufferedWriter) -> None:
+    # Pack the length as a 4-byte unsigned integer (little-endian)
+    data_len = len(data)
+    header = struct.pack("<I", data_len)
+    binary_file.write(header)
+    if data_len > 0:
+        binary_file.write(data)
+
+
+def _read_bytes(mmap_obj: mmap.mmap) -> bytes:
+    header = mmap_obj.read(4)
+    if not header:
+        raise ValueError("Missing header read")
+
+    if len(header) != 4:
+        raise ValueError("Incomplete header read")
+
+    data_len = struct.unpack("<I", header)[0]
+    if data_len == 0:
+        return b""
+
+    data = mmap_obj.read(data_len)
+
+    if len(data) != data_len:
+        raise ValueError("Incomplete data read")
+
+    return data
+
+
 def find_loaded_library(lib_name) -> str | None:
     """
     According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
@@ -167,6 +215,8 @@ def __init__(self):
         self.python_malloc_callback = self._python_malloc_callback
         self.python_free_callback = self._python_free_callback
 
+        self.cache_filepath = ""
+
     def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
         Internal method to store the allocation data
@@ -198,7 +248,25 @@ def _python_free_callback(self, ptr: int) -> HandleType:
         )
         return data.handle
 
-    def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
+    def _delete_cache_file(self):
+        """
+        Remove the sleep cache file if it exists.
+        """
+        if self.cache_filepath != "":
+            filepath = self.cache_filepath
+            self.cache_filepath = ""
+            try:
+                with contextlib.suppress(FileNotFoundError):
+                    os.remove(filepath)
+                    logger.info("cache file %s deleted", filepath)
+            except Exception as e:
+                logger.warning(
+                    "failed to delete sleep cache file %s.", filepath, exc_info=e
+                )
+
+    def sleep(
+        self, level: int = 1, offload_tags: tuple[str, ...] | str | None = None
+    ) -> None:
         """
         Put the allocator in sleep mode.
         All data in the memory allocation with the specified tag will be
@@ -219,6 +287,31 @@ def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
         total_bytes = 0
         backup_bytes = 0
 
+        # remove the previous cache file if it exists
+        self._delete_cache_file()
+
+        # level 3 writes the weights to a file
+        if level == 3:
+            unique_id = uuid.uuid4().hex
+            self.cache_filepath = os.path.join(
+                envs.VLLM_CACHE_ROOT, f"sleep_cache_{unique_id}.bin"
+            )
+            logger.info(
+                "sleep level %d writing to cache file %s", level, self.cache_filepath
+            )
+            with open(self.cache_filepath, "wb") as binary_file:
+                for ptr, data in self.pointer_to_data.items():
+                    handle = data.handle
+                    if data.tag in offload_tags:
+                        size_in_bytes = handle[1]
+                        data = _copy_from_cuda_to_bytes(ptr, size_in_bytes)
+                        _write_bytes(data, binary_file)
+                    else:
+                        _write_bytes(b"", binary_file)
+                    unmap_and_release(handle)
+            return
+
+        # handle other levels
         for ptr, data in self.pointer_to_data.items():
             handle = data.handle
             total_bytes += handle[1]
@@ -258,6 +351,25 @@ def wake_up(self, tags: list[str] | None = None) -> None:
         back to GPU memory. If None, all memory allocation will be loaded
         back to GPU memory.
         """
+        if self.cache_filepath != "":
+            logger.info("wake_up reading from cache file %s", self.cache_filepath)
+            with (
+                open(self.cache_filepath, "rb") as bin_file,
+                mmap.mmap(
+                    bin_file.fileno(), length=0, access=mmap.ACCESS_READ
+                ) as mmap_obj,
+            ):
+                for ptr, data in self.pointer_to_data.items():
+                    handle = data.handle
+                    create_and_map(handle)
+                    data = _read_bytes(mmap_obj)
+                    if len(data) > 0:
+                        _copy_from_bytes_to_cuda(ptr, data)
+
+            # remove the file
+            self._delete_cache_file()
+            return
+
         for ptr, data in self.pointer_to_data.items():
             if tags is None or data.tag in tags:
                 handle = data.handle
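
The cache file written above is a flat sequence of length-prefixed records, one per tracked allocation in iteration order: a little-endian uint32 payload length followed by the payload, with a zero-length record standing in for allocations that were not offloaded. A self-contained round trip of that framing, using hypothetical helper names write_record/read_record in place of the module's private functions:

import mmap
import struct
import tempfile

def write_record(data: bytes, f) -> None:
    # 4-byte little-endian length header, then the payload (if any)
    f.write(struct.pack("<I", len(data)))
    if data:
        f.write(data)

def read_record(m: mmap.mmap) -> bytes:
    (n,) = struct.unpack("<I", m.read(4))
    return m.read(n) if n else b""

with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
    write_record(b"weights-chunk", f)  # an offloaded allocation
    write_record(b"", f)               # a non-offloaded allocation
    path = f.name

with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
    assert read_record(m) == b"weights-chunk"
    assert read_record(m) == b""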

vllm/entrypoints/llm.py

Lines changed: 5 additions & 0 deletions
@@ -1513,6 +1513,11 @@ def sleep(self, level: int = 1):
                 sleep is good for sleeping and waking up the engine to run a
                 different model or update the model, where previous model
                 weights are not needed. It reduces CPU memory pressure.
+                Level 3 sleep offloads the model weights to disk and
+                discards the kv cache. The model weights are not backed up
+                in CPU memory, and the content of the kv cache is
+                forgotten. Level 3 sleep minimizes CPU memory usage and
+                loads the weights efficiently from disk when woken up.
         """
         self.reset_prefix_cache()
         self.llm_engine.sleep(level=level)

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 1 deletion
@@ -116,7 +116,9 @@ def sleep(self, level: int = 1) -> None:
         }
 
         allocator = CuMemAllocator.get_instance()
-        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+        allocator.sleep(
+            level, offload_tags=("weights",) if level == 1 or level == 3 else tuple()
+        )
         free_bytes_after_sleep, total = torch.cuda.mem_get_info()
         freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         used_bytes = total - free_bytes_after_sleep
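
In effect, the worker now forwards the level and tags the weights for offload on levels 1 and 3 (backed up to pinned CPU memory or to disk, respectively), while level 2 passes no tags, so everything is discarded without a backup. A sketch of that dispatch, with sleep_args as a hypothetical helper rather than vLLM API:

def sleep_args(level: int) -> tuple[int, tuple[str, ...]]:
    # levels 1 and 3 back up the weights; level 2 keeps no backup
    offload_tags = ("weights",) if level in (1, 3) else tuple()
    return level, offload_tags

assert sleep_args(1) == (1, ("weights",))
assert sleep_args(2) == (2, ())
assert sleep_args(3) == (3, ("weights",))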
