@@ -20,6 +20,16 @@ def torch_dtype_to_tl(dtype: torch.dtype):
     raise ValueError(f"Unsupported dtype: {dtype}")
 
 
+@triton.autotune(
+    configs=[
+        triton.Config({"ROWS_PER_PID": r}, num_warps=num_warps, num_stages=num_stages)
+        for r in [1, 2, 4, 8, 16, 32, 64, 128]
+        for num_warps in [1, 2, 4, 8, 16, 32]
+        for num_stages in [1, 2, 3]
+    ],
+    key=["N", "topk"],
+    cache_results=True,
+)
 @triton.jit
 def _topk_softmax_kernel(
     logits_ptr,
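
Note: the decorator above defines an 8 × 6 × 3 = 144-config search space; `key=["N", "topk"]` re-triggers tuning only when the expert count or `topk` changes, and `cache_results=True` caches the tuning results so repeated launches skip re-benchmarking. A minimal sketch of the same pattern on a toy kernel (`toy_copy_kernel` and its shapes are hypothetical, not part of this change):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"ROWS_PER_PID": r}, num_warps=w)
        for r in [1, 4, 16]
        for w in [2, 4]
    ],
    key=["N"],  # re-tune whenever N changes
)
@triton.jit
def toy_copy_kernel(src_ptr, dst_ptr, M, N, ROWS_PER_PID: tl.constexpr):
    # the autotuner injects ROWS_PER_PID; callers never pass it explicitly
    offs = tl.program_id(0) * ROWS_PER_PID + tl.arange(0, ROWS_PER_PID)
    mask = offs < M
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)
```
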
@@ -37,8 +47,8 @@ def _topk_softmax_kernel(
     stride_ik,
     BLOCK_N: tl.constexpr,
     RENORM: tl.constexpr,
-    num_stages: tl.constexpr,
     ROWS_PER_PID: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
     pid = tl.program_id(0)
     num_programs = tl.num_programs(0)
@@ -48,8 +58,8 @@ def _topk_softmax_kernel(
4858 mask_n = offs_n < N
4959 store_mask = offs_k < topk
5060
51- # specify topk<=2 and RENORM specialization by tl.constexpr,
52- # similar as ` constexpr if` in C++17
61+ # impl topk<=2 and RENORM specialization by tl.constexpr,
62+ # same as constexpr if in C++17
5363 if topk == 1 :
5464 for row_idx in tl .range (pid , M , num_programs , num_stages ):
5565 logits = tl .load (
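
Note: the rewritten comment refers to Triton's compile-time specialization: because `topk` and `RENORM` are `tl.constexpr` arguments, the JIT compiles one kernel variant per value and eliminates dead branches, much like C++17 `if constexpr`. A hedged standalone illustration (the kernel and its names are hypothetical):

```python
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DOUBLE:  # resolved at compile time; only one branch is codegen'd
        x = x * 2.0
    tl.store(x_ptr + offs, x, mask=mask)
```
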
@@ -64,14 +74,11 @@ def _topk_softmax_kernel(
             denominator = tl.sum(numerator, axis=0)
             logits = numerator / denominator
 
-            cur_max = tl.max(logits, axis=0)
+            cur_max = 1 if RENORM else tl.max(logits, axis=0)
             cur_idx = tl.argmax(logits, axis=0)
 
-            if RENORM:
-                cur_max = 1
-
             tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
-            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_wk, cur_idx)
+            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)
 
     elif topk == 2:
         for row_idx in tl.range(pid, M, num_programs, num_stages):
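
Note: the `topk == 1` simplification relies on the identity that renormalizing a single selected weight always yields `w / w == 1.0` regardless of the logits, so the stored weight can be the constant `1` and the `if RENORM` branch folds into one conditional expression. A quick PyTorch check of that identity (illustrative only):

```python
import torch

logits = torch.randn(8)
probs = logits.softmax(dim=-1)
top_val, top_idx = probs.max(dim=-1)
# the renormalized top-1 weight is always exactly 1.0
assert torch.allclose(top_val / top_val, torch.tensor(1.0))
```

The same hunk also fixes the index store to use `stride_ik` (the indices tensor's own stride) instead of `stride_wk`; the two strides coincide for contiguous `(M, topk)` outputs, which is why the bug was latent.
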
@@ -103,7 +110,7 @@ def _topk_softmax_kernel(
             tl.store(weights_ptr + row_idx * stride_wm, val0)
             tl.store(indices_ptr + row_idx * stride_im, idx0)
             tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
-            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_wk, idx1)
+            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)
 
     else:
         topk_vals = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.float32) + float(
@@ -113,7 +120,10 @@ def _topk_softmax_kernel(
 
         rows = tl.arange(0, ROWS_PER_PID)
         for row_idx in tl.range(
-            pid * ROWS_PER_PID, M, num_programs * ROWS_PER_PID, num_stages
+            pid * ROWS_PER_PID,
+            M,
+            num_programs * ROWS_PER_PID,
+            num_stages,
         ):
             row_indices = row_idx + rows  # [ROWS_PER_PID,]
             row_mask = row_indices < M
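
Note: each program processes `ROWS_PER_PID` consecutive rows per step, starting at `pid * ROWS_PER_PID` and striding by `num_programs * ROWS_PER_PID`; the final `tl.range` argument pipelines the loop `num_stages` deep. With the new `META`-driven grid, `num_programs * ROWS_PER_PID >= M`, so each program typically runs a single iteration, but the strided form stays correct for any smaller grid. A plain-Python analogue of the row partitioning (illustrative only):

```python
def block_starts(pid, M, num_programs, ROWS_PER_PID):
    # program `pid` handles rows [s, s + ROWS_PER_PID) for each start s
    return list(range(pid * ROWS_PER_PID, M, num_programs * ROWS_PER_PID))

# e.g. M = 10, 2 programs, 2 rows each:
assert block_starts(0, 10, 2, 2) == [0, 4, 8]
assert block_starts(1, 10, 2, 2) == [2, 6]
```
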
@@ -183,16 +193,15 @@ def fused_topk_softmax(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     M, N = router_logits.shape  # num_tokens, num_experts
 
-    weights = torch.empty(
-        (M, topk), device=router_logits.device, dtype=router_logits.dtype
-    )
+    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)
 
     BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
     topk_padded = triton.next_power_of_2(topk)
-    grid = (M,)
-    num_stages = 2
-    ROWS_PER_PID = 4
+
+    # let the autotuner pick the launch grid (number of thread blocks);
+    # see https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+    grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
 
     _topk_softmax_kernel[grid](
         logits_ptr=router_logits,
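
Note: Triton invokes a callable grid with the autotuner's winning meta-dict, so the program count follows the tuned `ROWS_PER_PID` automatically; this is also why `num_stages` and `ROWS_PER_PID` disappear from the call site in the next hunk, where they would otherwise conflict with the tuned values. A self-contained sketch of the grid callable (`M` is illustrative):

```python
import triton

M = 4096  # illustrative token count

# Triton calls this with the selected config's meta-parameters
grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
assert grid({"ROWS_PER_PID": 16}) == (256,)
```
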
@@ -210,8 +219,6 @@ def fused_topk_softmax(
         stride_ik=indices.stride(1),
         BLOCK_N=BLOCK_N,
         RENORM=renormalize,
-        num_stages=num_stages,
-        ROWS_PER_PID=ROWS_PER_PID,
     )
 
     return weights, indices
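
For reference, a hypothetical call site (shapes are illustrative; the keyword names are assumed from the function body above):

```python
import torch

router_logits = torch.randn(4096, 64, device="cuda", dtype=torch.float16)  # [num_tokens, num_experts]
weights, indices = fused_topk_softmax(router_logits, topk=2, renormalize=True)
# weights: [4096, 2] float32 routing weights (now float32 regardless of input dtype)
# indices: [4096, 2] int32 expert ids
```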