Skip to content

Commit 3e1ad40

Browse files
authored
[Model Runner V2] Add apply_temperature option to gumbel_sample (#29276)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 62d54ba commit 3e1ad40

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

vllm/v1/worker/gpu/sampler.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,10 @@ def sample(
6868

6969
sampled = gumbel_sample(
7070
logits,
71-
is_greedy,
71+
sampling_metadata.temperature,
7272
sampling_metadata.seeds,
7373
sampling_metadata.pos,
74+
apply_temperature=False,
7475
)
7576
return sampled, logits if return_logits else None
7677

@@ -85,9 +86,10 @@ def _gumbel_sample_kernel(
8586
logits_stride,
8687
seeds_ptr,
8788
pos_ptr,
88-
is_greedy_ptr,
89+
temp_ptr,
8990
vocab_size,
9091
BLOCK_SIZE: tl.constexpr,
92+
APPLY_TEMPERATURE: tl.constexpr,
9193
):
9294
req_idx = tl.program_id(0)
9395
block_idx = tl.program_id(1)
@@ -99,8 +101,8 @@ def _gumbel_sample_kernel(
99101
other=float("-inf"),
100102
)
101103

102-
is_greedy = tl.load(is_greedy_ptr + req_idx)
103-
if not is_greedy:
104+
temp = tl.load(temp_ptr + req_idx)
105+
if temp != 0.0:
104106
# Calculate the seed for gumbel noise.
105107
seed = tl.load(seeds_ptr + req_idx)
106108
pos = tl.load(pos_ptr + req_idx)
@@ -111,6 +113,11 @@ def _gumbel_sample_kernel(
111113
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
112114
gumbel_noise = gumbel_noise.to(tl.float32)
113115

116+
# Apply temperature.
117+
if APPLY_TEMPERATURE:
118+
# NOTE(woosuk): Use div_rn to match the behavior of torch.
119+
logits = tl.div_rn(logits, temp.to(tl.float32))
120+
114121
# Apply gumbel noise.
115122
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
116123

@@ -123,9 +130,10 @@ def _gumbel_sample_kernel(
123130

124131
def gumbel_sample(
125132
logits: torch.Tensor, # [num_reqs, vocab_size]
126-
is_greedy: torch.Tensor, # [num_reqs]
133+
temperature: torch.Tensor, # [num_reqs]
127134
seed: torch.Tensor, # [num_reqs]
128135
pos: torch.Tensor, # [num_reqs]
136+
apply_temperature: bool,
129137
) -> torch.Tensor:
130138
num_reqs, vocab_size = logits.shape
131139
BLOCK_SIZE = 1024
@@ -151,9 +159,10 @@ def gumbel_sample(
151159
logits.stride(0),
152160
seed,
153161
pos,
154-
is_greedy,
162+
temperature,
155163
vocab_size,
156164
BLOCK_SIZE=BLOCK_SIZE,
165+
APPLY_TEMPERATURE=apply_temperature,
157166
)
158167
# NOTE(woosuk): Use int64 for later indexing.
159168
max_block_idx = local_max.argmax(dim=-1, keepdim=True)

0 commit comments

Comments (0)