@@ -38,6 +38,7 @@ def _topk_softmax_kernel(
3838 BLOCK_N : tl .constexpr ,
3939 RENORM : tl .constexpr ,
4040 num_stages : tl .constexpr ,
41+ ROWS_PER_PID : tl .constexpr ,
4142):
4243 pid = tl .program_id (0 )
4344 num_programs = tl .num_programs (0 )
@@ -105,47 +106,73 @@ def _topk_softmax_kernel(
105106 tl .store (indices_ptr + row_idx * stride_im + 1 * stride_wk , idx1 )
106107
107108 else :
108- topk_vals = tl .zeros ([topk_padded ], dtype = tl .float32 ) + float ("-inf" )
109- topk_idxs = tl .zeros ([topk_padded ], dtype = tl .int32 )
110-
111- for row_idx in tl .range (pid , M , num_programs , num_stages ):
109+ topk_vals = tl .zeros ([ROWS_PER_PID , topk_padded ], dtype = tl .float32 ) + float (
110+ "-inf"
111+ )
112+ topk_idxs = tl .zeros ([ROWS_PER_PID , topk_padded ], dtype = tl .int32 )
113+
114+ rows = tl .arange (0 , ROWS_PER_PID )
115+ for row_idx in tl .range (
116+ pid * ROWS_PER_PID , M , num_programs * ROWS_PER_PID , num_stages
117+ ):
118+ row_indices = row_idx + rows # [ROWS_PER_POD,]
119+ row_mask = row_indices < M
120+ # broadcast to [ROWS_PER_PID, BLOCKN]
112121 logits = tl .load (
113- logits_ptr + row_idx * stride_lm + offs_n * stride_ln ,
114- mask = mask_n ,
122+ logits_ptr
123+ + row_indices [:, None ] * stride_lm # [ROWS_PER_PID, 1]
124+ + offs_n [None , :] * stride_ln , # [1, BLOCK_N]
125+ mask = row_mask [:, None ] # [ROWS_PER_PID,1]
126+ & mask_n [None , :], # [1, BLOCKN]
115127 other = float ("-inf" ),
116128 )
117129
118130 if not RENORM :
119- row_sub_max = logits - tl .max (logits , axis = 0 )
131+ row_sub_max = logits - tl .max (
132+ logits , axis = 1 , keep_dims = True
133+ ) # [ROWS_PER_PID, BLOCK_N] - [ROWS_PER_PID,1]
120134 numerator = tl .exp (row_sub_max )
121- denominator = tl .sum (numerator , axis = 0 )
135+ denominator = tl .sum (
136+ numerator , axis = 1 , keep_dims = True
137+ ) # [ROWS_PER_PID, BLOCKN]
122138 logits = numerator / denominator
123139
124140 for k in tl .static_range (topk ):
125- cur_max = tl .max (logits , axis = 0 )
126- cur_idx = tl .argmax (logits , axis = 0 )
141+ cur_max = tl .max (logits , axis = 1 , keep_dims = True ) # [ROWS_PER_PID, 1]
142+ cur_idx = tl .argmax (logits , axis = 1 , keep_dims = True )
127143
128144 k_mask = offs_k == k
129- topk_vals = tl .where (k_mask , cur_max , topk_vals )
145+ topk_vals = tl .where (
146+ k_mask , cur_max , topk_vals
147+ ) # [ROWS_PER PID, 1], [ROWS_PER PID, topkpadded]
130148 topk_idxs = tl .where (k_mask , cur_idx , topk_idxs )
131149
132- logits = tl .where (offs_n == cur_idx , float ("-inf" ), logits )
150+ mask_selected = cur_idx == offs_n [None , :] # [ROWSPERPID,1] [1,BLOCKN]
151+ logits = tl .where (mask_selected , float ("-inf" ), logits )
133152
134153 if RENORM :
135- topk_vals = topk_vals - tl .max (topk_vals , axis = 0 )
154+ topk_vals = topk_vals - tl .max (
155+ topk_vals , axis = 1 , keep_dims = True
156+ ) # [ROWSPERPID, topkpadded] - [ROWSPERPID,1]
136157 numerator = tl .exp (topk_vals )
137- denominator = tl .sum (numerator , axis = 0 )
138- topk_vals = numerator / denominator
158+ denominator = tl .sum (
159+ numerator , axis = 1 , keep_dims = True
160+ ) # [ROWSPERPID,1]
161+ topk_vals = numerator / denominator # [ROWSPERPID,topkpadded]
139162
140163 tl .store (
141- weights_ptr + row_idx * stride_wm + offs_k * stride_wk ,
164+ weights_ptr
165+ + row_indices [:, None ] * stride_wm # [ROWSPERPID,1]
166+ + offs_k [None , :] * stride_wk , # [1, topkpadded]
142167 topk_vals ,
143- mask = store_mask ,
168+ mask = row_mask [:, None ] & store_mask [ None , :], # [1, topkpadded]
144169 )
145170 tl .store (
146- indices_ptr + row_idx * stride_im + offs_k * stride_ik ,
171+ indices_ptr
172+ + row_indices [:, None ] * stride_im
173+ + offs_k [None , :] * stride_ik ,
147174 topk_idxs ,
148- mask = store_mask ,
175+ mask = row_mask [:, None ] & store_mask [ None , :] ,
149176 )
150177
151178
@@ -165,6 +192,7 @@ def fused_topk_softmax(
165192 topk_padded = triton .next_power_of_2 (topk )
166193 grid = (M ,)
167194 num_stages = 2
195+ ROWS_PER_PID = 4
168196
169197 _topk_softmax_kernel [grid ](
170198 logits_ptr = router_logits ,
@@ -183,6 +211,7 @@ def fused_topk_softmax(
183211 BLOCK_N = BLOCK_N ,
184212 RENORM = renormalize ,
185213 num_stages = num_stages ,
214+ ROWS_PER_PID = ROWS_PER_PID ,
186215 )
187216
188217 return weights , indices
0 commit comments