@@ -15,8 +15,8 @@ def _penalties_and_temperature_kernel(
     presence_penalty_ptr,
     temperature_ptr,
     idx_mapping_ptr,
-    prompt_bin_counts_ptr,
-    prompt_bin_counts_stride,
+    prompt_bin_mask_ptr,
+    prompt_bin_mask_stride,
     output_bin_counts_ptr,
     output_bin_counts_stride,
     vocab_size,
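The packed mask replaces the per-request prompt token counts with one bit per vocab entry. A rough sizing sketch (the shapes, dtypes, and example dimensions below are assumptions for illustration, not taken from the diff):

import torch

max_num_reqs, vocab_size = 256, 152_064  # illustrative sizes
# Old buffer: one int32 count per (request, token) pair, ~148 MiB at these sizes.
prompt_bin_counts = torch.zeros(max_num_reqs, vocab_size, dtype=torch.int32)
# New buffer: one int32 word per 32 tokens, ~4.6 MiB at these sizes (32x smaller).
prompt_bin_mask = torch.zeros(max_num_reqs, (vocab_size + 31) // 32, dtype=torch.int32)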
@@ -54,13 +54,16 @@ def _penalties_and_temperature_kernel(
 
     # Apply repetition penalties.
     if use_rep_penalty:
-        prompt_bin_counts = tl.load(
-            prompt_bin_counts_ptr
-            + req_state_idx * prompt_bin_counts_stride
-            + block,
-            mask=mask,
+        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        packed_mask = tl.load(
+            prompt_bin_mask_ptr
+            + req_state_idx * prompt_bin_mask_stride
+            + packed_block,
+            mask=packed_block < tl.cdiv(vocab_size, 32),
         )
-        prompt_bin_mask = prompt_bin_counts > 0
+        prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
+        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
+
         # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
         scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
         # If logits are positive, divide by penalty, otherwise multiply by penalty.
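The unpack above broadcasts each loaded 32-bit word against bit positions 0..31 and reshapes, so bit b of word w lands at flat index w * 32 + b, matching the token order within the block. A standalone round-trip check in PyTorch (int64 is used here for simplicity; the kernel works on int32 words):

import torch

BLOCK_SIZE = 128
ref = torch.randint(0, 2, (BLOCK_SIZE,))                       # 0/1 reference mask
packed = (ref.reshape(-1, 32) << torch.arange(32)).sum(dim=1)  # pack 32 bits per word
bits = (packed[:, None] >> torch.arange(32)[None, :]) & 1      # unpack as in the kernel
assert torch.equal(bits.reshape(BLOCK_SIZE), ref)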
@@ -93,8 +96,8 @@ def apply_penalties_and_temperature(
         sampling_metadata.presence_penalty,
         sampling_metadata.temperature,
         sampling_metadata.idx_mapping,
-        sampling_metadata.prompt_bin_counts,
-        sampling_metadata.prompt_bin_counts.stride(0),
+        sampling_metadata.prompt_bin_mask,
+        sampling_metadata.prompt_bin_mask.stride(0),
         sampling_metadata.output_bin_counts,
         sampling_metadata.output_bin_counts.stride(0),
         vocab_size,
@@ -107,7 +110,7 @@ def _bincount_kernel(
     prefill_token_ids_ptr,
     prefill_len,
     prompt_len,
-    prompt_bin_counts_ptr,
+    prompt_bin_mask_ptr,
     output_bin_counts_ptr,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -119,7 +122,10 @@ def _bincount_kernel(
     if block_idx * BLOCK_SIZE < prompt_len:
         mask = block < prompt_len
         prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
-        tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
+        idx = prefill_tokens // 32
+        bit_idx = prefill_tokens % 32
+        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
+        tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
     if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
         mask = block < prefill_len
         mask &= block >= prompt_len
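Each prompt token now sets bit (token % 32) of word (token // 32); tl.atomic_or keeps concurrent lanes that hit the same word from losing bits, where the old code atomically incremented a per-token counter. A host-side sketch of the same packing for one request (token ids and int64 words are illustrative; the kernel uses int32 words):

import torch

vocab_size = 1_000
prompt_tokens = [3, 64, 64, 999]                                # illustrative token ids
mask = torch.zeros((vocab_size + 31) // 32, dtype=torch.int64)
for t in prompt_tokens:
    mask[t // 32] |= 1 << (t % 32)  # duplicates simply re-set the same bit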
@@ -131,18 +137,18 @@ def bincount(
     prefill_token_ids: torch.Tensor,
     prefill_len: int,
     prompt_len: int,
-    prompt_bin_counts: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
     output_bin_counts: torch.Tensor,
 ) -> None:
-    prompt_bin_counts.zero_()
+    prompt_bin_mask.zero_()
     output_bin_counts.zero_()
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
     _bincount_kernel[(num_blocks,)](
         prefill_token_ids,
         prefill_len,
         prompt_len,
-        prompt_bin_counts,
+        prompt_bin_mask,
         output_bin_counts,
         BLOCK_SIZE=BLOCK_SIZE,
     )
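For downstream code, prompt membership becomes a bit test on the packed row rather than a count comparison. A hypothetical helper, not part of the diff:

import torch

def token_in_prompt(prompt_bin_mask: torch.Tensor, req_idx: int, token_id: int) -> bool:
    # Check bit (token_id % 32) of word (token_id // 32) in the request's packed row.
    word = prompt_bin_mask[req_idx, token_id // 32]
    return bool((word >> (token_id % 32)) & 1)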