
Commit ec38a73

[Model Runner V2] Use packed mask for prompt bin counts (#29756)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 21c2627 commit ec38a73

3 files changed: +35 −25 lines
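The change replaces the per-request prompt bin counts (one int32 per vocabulary entry) with a packed mask (one bit per vocabulary entry, 32 packed into each int32 word). The penalties kernel only ever used the prompt counts as a presence test (prompt_bin_counts > 0), so the packed form loses nothing for repetition penalty while cutting the tensor's size by 32x. Below is a minimal sketch of the two layouts in plain Python with made-up values; it is an illustration of why they agree, not the vLLM code itself.

# Minimal sketch (illustrative values, plain Python; not the vLLM kernels).
vocab_size = 1000
prompt_tokens = [3, 42, 42, 999, 3]

# Old layout: one int32 count per vocab entry, O(vocab_size) per request.
bin_counts = [0] * vocab_size
for t in prompt_tokens:
    bin_counts[t] += 1

# New layout: one bit per vocab entry, 32 packed into each int32 word,
# O(vocab_size / 32) per request. The kernel sets these bits with tl.atomic_or.
num_words = (vocab_size + 31) // 32
bin_mask = [0] * num_words
for t in prompt_tokens:
    bin_mask[t // 32] |= 1 << (t % 32)

# Repetition penalty only needs presence, so the two representations agree:
for t in range(vocab_size):
    assert ((bin_mask[t // 32] >> (t % 32)) & 1) == (1 if bin_counts[t] > 0 else 0)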

vllm/v1/worker/gpu/sample/metadata.py (4 additions, 4 deletions)
@@ -26,7 +26,7 @@ class SamplingMetadata:
 
     # For penalties
     idx_mapping: torch.Tensor
-    prompt_bin_counts: torch.Tensor
+    prompt_bin_mask: torch.Tensor
     output_bin_counts: torch.Tensor
 
     @classmethod
@@ -57,7 +57,7 @@ def make_dummy(
         # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
         # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
         # specialization and re-compilation at runtime.
-        prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
+        prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
         output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
 
         return cls(
@@ -71,7 +71,7 @@ def make_dummy(
             pos=pos,
             max_num_logprobs=max_num_logprobs,
             idx_mapping=idx_mapping,
-            prompt_bin_counts=prompt_bin_counts,
+            prompt_bin_mask=prompt_bin_mask,
             output_bin_counts=output_bin_counts,
         )
 
@@ -174,6 +174,6 @@ def expand_sampling_metadata(
         max_num_logprobs=sampling_metadata.max_num_logprobs,
         # TODO(woosuk): Support penalties with spec decoding.
         idx_mapping=sampling_metadata.idx_mapping,
-        prompt_bin_counts=sampling_metadata.prompt_bin_counts,
+        prompt_bin_mask=sampling_metadata.prompt_bin_mask,
         output_bin_counts=sampling_metadata.output_bin_counts,
     )
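For orientation, a hedged sketch of the penalty-related fields of SamplingMetadata after this rename, with shape comments inferred from the allocations in states.py further down; it is an abbreviated stand-in, not the real class.

from dataclasses import dataclass
import torch

@dataclass
class SamplingMetadataSketch:
    # Abbreviated stand-in for SamplingMetadata; only the fields this commit
    # touches are shown, and the shape comments are inferred from states.py.
    idx_mapping: torch.Tensor        # [num_reqs] mapping into persistent request state
    prompt_bin_mask: torch.Tensor    # [max_num_reqs, ceil(vocab_size / 32)], int32 packed bits
    output_bin_counts: torch.Tensor  # [max_num_reqs, vocab_size], int32 counts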

vllm/v1/worker/gpu/sample/penalties.py (21 additions, 15 deletions)
@@ -15,8 +15,8 @@ def _penalties_and_temperature_kernel(
     presence_penalty_ptr,
     temperature_ptr,
     idx_mapping_ptr,
-    prompt_bin_counts_ptr,
-    prompt_bin_counts_stride,
+    prompt_bin_mask_ptr,
+    prompt_bin_mask_stride,
     output_bin_counts_ptr,
     output_bin_counts_stride,
     vocab_size,
@@ -54,13 +54,16 @@ def _penalties_and_temperature_kernel(
 
     # Apply repetition penalties.
     if use_rep_penalty:
-        prompt_bin_counts = tl.load(
-            prompt_bin_counts_ptr
-            + req_state_idx * prompt_bin_counts_stride
-            + block,
-            mask=mask,
+        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        packed_mask = tl.load(
+            prompt_bin_mask_ptr
+            + req_state_idx * prompt_bin_mask_stride
+            + packed_block,
+            mask=packed_block < tl.cdiv(vocab_size, 32),
         )
-        prompt_bin_mask = prompt_bin_counts > 0
+        prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
+        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
+
         # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
         scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
         # If logits are positive, divide by penalty, otherwise multiply by penalty.
@@ -93,8 +96,8 @@ def apply_penalties_and_temperature(
         sampling_metadata.presence_penalty,
         sampling_metadata.temperature,
         sampling_metadata.idx_mapping,
-        sampling_metadata.prompt_bin_counts,
-        sampling_metadata.prompt_bin_counts.stride(0),
+        sampling_metadata.prompt_bin_mask,
+        sampling_metadata.prompt_bin_mask.stride(0),
         sampling_metadata.output_bin_counts,
         sampling_metadata.output_bin_counts.stride(0),
         vocab_size,
@@ -107,7 +110,7 @@ def _bincount_kernel(
     prefill_token_ids_ptr,
     prefill_len,
     prompt_len,
-    prompt_bin_counts_ptr,
+    prompt_bin_mask_ptr,
     output_bin_counts_ptr,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -119,7 +122,10 @@ def _bincount_kernel(
     if block_idx * BLOCK_SIZE < prompt_len:
         mask = block < prompt_len
         prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
-        tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
+        idx = prefill_tokens // 32
+        bit_idx = prefill_tokens % 32
+        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
+        tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
     if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
         mask = block < prefill_len
         mask &= block >= prompt_len
@@ -131,18 +137,18 @@ def bincount(
     prefill_token_ids: torch.Tensor,
     prefill_len: int,
     prompt_len: int,
-    prompt_bin_counts: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
     output_bin_counts: torch.Tensor,
 ) -> None:
-    prompt_bin_counts.zero_()
+    prompt_bin_mask.zero_()
     output_bin_counts.zero_()
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
     _bincount_kernel[(num_blocks,)](
         prefill_token_ids,
         prefill_len,
         prompt_len,
-        prompt_bin_counts,
+        prompt_bin_mask,
         output_bin_counts,
         BLOCK_SIZE=BLOCK_SIZE,
     )
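The non-obvious part of the kernel change is the read path: each program now loads BLOCK_SIZE // 32 packed words for its slice of the vocabulary, broadcasts them against a 0..31 bit index, and reshapes the result back into a per-token 0/1 mask. The following is a hedged PyTorch re-statement of that step with illustrative sizes (int64 words are used so plain shifts stay in range); it is an equivalence check, not the Triton kernel itself.

import torch

BLOCK_SIZE = 128                                  # a multiple of 32, as in the kernel
vocab_size = 4 * BLOCK_SIZE
ref_mask = torch.rand(vocab_size) < 0.1           # "token appears in the prompt"

# Pack 32 presence bits per word (what _bincount_kernel builds with tl.atomic_or).
bits = ref_mask.to(torch.int64).reshape(-1, 32)
packed = (bits << torch.arange(32)).sum(dim=1)    # [vocab_size // 32] words

# Unpack one BLOCK_SIZE-wide vocab block, mirroring the kernel's shift-and-mask read.
block_idx = 2
packed_block = packed[block_idx * BLOCK_SIZE // 32 : (block_idx + 1) * BLOCK_SIZE // 32]
unpacked = (packed_block[:, None] >> torch.arange(32)[None, :]) & 1
unpacked = unpacked.reshape(BLOCK_SIZE)

assert torch.equal(unpacked.bool(),
                   ref_mask[block_idx * BLOCK_SIZE : (block_idx + 1) * BLOCK_SIZE])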

vllm/v1/worker/gpu/states.py (10 additions, 6 deletions)
@@ -7,6 +7,7 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
+from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_uva_available
 from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
 from vllm.v1.outputs import LogprobsTensors
@@ -97,11 +98,14 @@ def __init__(
         self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
 
         # Statistics for penalties.
-        # TODO(woosuk): These tensors are rarely used but can be extremely large.
-        # Optimize the memory usage.
-        self.prompt_bin_counts = torch.zeros(
-            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
+        self.prompt_bin_mask = torch.zeros(
+            self.max_num_reqs,
+            cdiv(self.vocab_size, 32),
+            dtype=torch.int32,
+            device=self.device,
         )
+        # TODO(woosuk): This tensor is rarely used but can be extremely large.
+        # Optimize the memory usage.
         self.output_bin_counts = torch.zeros(
             self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
         )
@@ -167,7 +171,7 @@ def add_request(
             self.prefill_token_ids.gpu[req_idx],
             prefill_len,
             prompt_len,
-            self.prompt_bin_counts[req_idx],
+            self.prompt_bin_mask[req_idx],
             self.output_bin_counts[req_idx],
         )
 
@@ -239,7 +243,7 @@ def make_sampling_metadata(
             pos=pos,
             max_num_logprobs=max_num_logprobs,
             idx_mapping=idx_mapping,
-            prompt_bin_counts=self.prompt_bin_counts,
+            prompt_bin_mask=self.prompt_bin_mask,
             output_bin_counts=self.output_bin_counts,
         )
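The states.py change is where the savings come from: prompt_bin_mask is allocated as [max_num_reqs, cdiv(vocab_size, 32)] int32 words instead of [max_num_reqs, vocab_size] int32 counts, a 32x reduction, while output_bin_counts keeps the full layout (the remaining TODO), presumably because frequency penalty still needs real counts for generated tokens. Rough sizing with assumed example values, not values from this diff:

# Back-of-the-envelope sizing; max_num_reqs and vocab_size are illustrative
# assumptions, not values from this commit.
max_num_reqs = 1024
vocab_size = 151_936
int32_bytes = 4

old_prompt_bin_counts = max_num_reqs * vocab_size * int32_bytes                   # ~593 MiB
new_prompt_bin_mask = max_num_reqs * ((vocab_size + 31) // 32) * int32_bytes      # ~18.5 MiB

print(f"counts: {old_prompt_bin_counts / 2**20:.1f} MiB, "
      f"packed mask: {new_prompt_bin_mask / 2**20:.1f} MiB")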
