Skip to content

Commit 5c7c09a

Browse files
jthomson04 authored and khluu committed
[Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826)
Signed-off-by: jthomson04 <[email protected]> (cherry picked from commit 1528e07)
1 parent 7f71816 commit 5c7c09a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

vllm/v1/sample/logits_processor/builtin.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
110110
# Identify valid tokens using threshold comparison
111111
invalid_token_mask = probability_values < adjusted_min_p
112112
# Apply mask using boolean indexing
113-
logits[invalid_token_mask] = -float("inf")
113+
logits.masked_fill_(invalid_token_mask, -float("inf"))
114114
return logits
115115

116116

@@ -178,6 +178,10 @@ def __init__(
178178
self._device_tensor([], torch.int32),
179179
)
180180

181+
self.neg_inf_tensor = torch.tensor(
182+
-float("inf"), dtype=torch.float32, device=self.device
183+
)
184+
181185
def is_argmax_invariant(self) -> bool:
182186
"""By censoring stop tokens, min-tokens can change the outcome
183187
of the argmax operation in greedy sampling."""
@@ -229,7 +233,7 @@ def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
229233
def apply(self, logits: torch.Tensor) -> torch.Tensor:
230234
if self.min_toks:
231235
# Inhibit EOS token for requests which have not reached min length
232-
logits[self.logits_slice] = -float("inf")
236+
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
233237
return logits
234238

235239

0 commit comments

Comments
 (0)