@@ -68,9 +68,10 @@ def sample(
6868
6969 sampled = gumbel_sample (
7070 logits ,
71- is_greedy ,
71+ sampling_metadata . temperature ,
7272 sampling_metadata .seeds ,
7373 sampling_metadata .pos ,
74+ apply_temperature = False ,
7475 )
7576 return sampled , logits if return_logits else None
7677
@@ -85,9 +86,10 @@ def _gumbel_sample_kernel(
8586 logits_stride ,
8687 seeds_ptr ,
8788 pos_ptr ,
88- is_greedy_ptr ,
89+ temp_ptr ,
8990 vocab_size ,
9091 BLOCK_SIZE : tl .constexpr ,
92+ APPLY_TEMPERATURE : tl .constexpr ,
9193):
9294 req_idx = tl .program_id (0 )
9395 block_idx = tl .program_id (1 )
@@ -99,8 +101,8 @@ def _gumbel_sample_kernel(
99101 other = float ("-inf" ),
100102 )
101103
102- is_greedy = tl .load (is_greedy_ptr + req_idx )
103- if not is_greedy :
104+ temp = tl .load (temp_ptr + req_idx )
105+ if temp != 0.0 :
104106 # Calculate the seed for gumbel noise.
105107 seed = tl .load (seeds_ptr + req_idx )
106108 pos = tl .load (pos_ptr + req_idx )
@@ -111,6 +113,11 @@ def _gumbel_sample_kernel(
111113 gumbel_noise = - tl .log (- tl .log (r + 1e-20 ) + 1e-20 )
112114 gumbel_noise = gumbel_noise .to (tl .float32 )
113115
116+ # Apply temperature.
117+ if APPLY_TEMPERATURE :
118+ # NOTE(woosuk): Use div_rn to match the behavior of torch.
119+ logits = tl .div_rn (logits , temp .to (tl .float32 ))
120+
114121 # Apply gumbel noise.
115122 logits = tl .where (mask , logits + gumbel_noise , float ("-inf" ))
116123
@@ -123,9 +130,10 @@ def _gumbel_sample_kernel(
123130
124131def gumbel_sample (
125132 logits : torch .Tensor , # [num_reqs, vocab_size]
126- is_greedy : torch .Tensor , # [num_reqs]
133+ temperature : torch .Tensor , # [num_reqs]
127134 seed : torch .Tensor , # [num_reqs]
128135 pos : torch .Tensor , # [num_reqs]
136+ apply_temperature : bool ,
129137) -> torch .Tensor :
130138 num_reqs , vocab_size = logits .shape
131139 BLOCK_SIZE = 1024
@@ -151,9 +159,10 @@ def gumbel_sample(
151159 logits .stride (0 ),
152160 seed ,
153161 pos ,
154- is_greedy ,
162+ temperature ,
155163 vocab_size ,
156164 BLOCK_SIZE = BLOCK_SIZE ,
165+ APPLY_TEMPERATURE = apply_temperature ,
157166 )
158167 # NOTE(woosuk): Use int64 for later indexing.
159168 max_block_idx = local_max .argmax (dim = - 1 , keepdim = True )
0 commit comments