@@ -20,6 +20,21 @@ def torch_dtype_to_tl(dtype: torch.dtype):
2020 raise ValueError (f"Unsupported dtype: { dtype } " )
2121
2222
23+ @triton .autotune (
24+ configs = [
25+ triton .Config ({"ROWS_PER_PID" : 1 , "num_stages" : 2 , "num_warps" : 1 }),
26+ triton .Config ({"ROWS_PER_PID" : 1 , "num_stages" : 2 , "num_warps" : 2 }),
27+ triton .Config ({"ROWS_PER_PID" : 2 , "num_stages" : 2 , "num_warps" : 2 }),
28+ triton .Config ({"ROWS_PER_PID" : 4 , "num_stages" : 2 , "num_warps" : 2 }),
29+ triton .Config ({"ROWS_PER_PID" : 16 , "num_stages" : 2 , "num_warps" : 4 }),
30+ triton .Config ({"ROWS_PER_PID" : 16 , "num_stages" : 3 , "num_warps" : 8 }),
31+ triton .Config ({"ROWS_PER_PID" : 32 , "num_stages" : 2 , "num_warps" : 4 }),
32+ triton .Config ({"ROWS_PER_PID" : 32 , "num_stages" : 3 , "num_warps" : 8 }),
33+ triton .Config ({"ROWS_PER_PID" : 64 , "num_stages" : 3 , "num_warps" : 8 }),
34+ triton .Config ({"ROWS_PER_PID" : 128 , "num_stages" : 3 , "num_warps" : 8 }),
35+ ],
36+ key = ["N" , "topk" , "RENORM" ],
37+ )
2338@triton .jit
2439def _topk_softmax_kernel (
2540 logits_ptr ,
@@ -37,8 +52,8 @@ def _topk_softmax_kernel(
3752 stride_ik ,
3853 BLOCK_N : tl .constexpr ,
3954 RENORM : tl .constexpr ,
40- num_stages : tl .constexpr ,
4155 ROWS_PER_PID : tl .constexpr ,
56+ num_stages : tl .constexpr ,
4257):
4358 pid = tl .program_id (0 )
4459 num_programs = tl .num_programs (0 )
@@ -48,10 +63,10 @@ def _topk_softmax_kernel(
4863 mask_n = offs_n < N
4964 store_mask = offs_k < topk
5065
51- # specify topk<=2 and RENORM specialization by tl.constexpr,
52- # similar as ` constexpr if` in C++17
66+ # implement the topk<=2 and RENORM specializations via tl.constexpr,
67+ # analogous to `if constexpr` in C++17
5368 if topk == 1 :
54- for row_idx in tl .range (pid , M , num_programs , num_stages ):
69+ for row_idx in tl .range (pid , M , num_programs , num_stages , warp_specialize = True ):
5570 logits = tl .load (
5671 logits_ptr + row_idx * stride_lm + offs_n * stride_ln ,
5772 mask = mask_n ,
@@ -71,10 +86,10 @@ def _topk_softmax_kernel(
7186 cur_max = 1
7287
7388 tl .store (weights_ptr + row_idx * stride_wm + 0 * stride_wk , cur_max )
74- tl .store (indices_ptr + row_idx * stride_im + 0 * stride_wk , cur_idx )
89+ tl .store (indices_ptr + row_idx * stride_im + 0 * stride_ik , cur_idx )
7590
7691 elif topk == 2 :
77- for row_idx in tl .range (pid , M , num_programs , num_stages ):
92+ for row_idx in tl .range (pid , M , num_programs , num_stages , warp_specialize = True ):
7893 logits = tl .load (
7994 logits_ptr + row_idx * stride_lm + offs_n * stride_ln ,
8095 mask = mask_n ,
@@ -103,7 +118,7 @@ def _topk_softmax_kernel(
103118 tl .store (weights_ptr + row_idx * stride_wm , val0 )
104119 tl .store (indices_ptr + row_idx * stride_im , idx0 )
105120 tl .store (weights_ptr + row_idx * stride_wm + 1 * stride_wk , val1 )
106- tl .store (indices_ptr + row_idx * stride_im + 1 * stride_wk , idx1 )
121+ tl .store (indices_ptr + row_idx * stride_im + 1 * stride_ik , idx1 )
107122
108123 else :
109124 topk_vals = tl .zeros ([ROWS_PER_PID , topk_padded ], dtype = tl .float32 ) + float (
@@ -113,7 +128,11 @@ def _topk_softmax_kernel(
113128
114129 rows = tl .arange (0 , ROWS_PER_PID )
115130 for row_idx in tl .range (
116- pid * ROWS_PER_PID , M , num_programs * ROWS_PER_PID , num_stages
131+ pid * ROWS_PER_PID ,
132+ M ,
133+ num_programs * ROWS_PER_PID ,
134+ num_stages ,
135+ warp_specialize = True ,
117136 ):
118137 row_indices = row_idx + rows # [ROWS_PER_PID,]
119138 row_mask = row_indices < M
@@ -183,16 +202,15 @@ def fused_topk_softmax(
183202) -> tuple [torch .Tensor , torch .Tensor ]:
184203 M , N = router_logits .shape # num_tokens, num_experts
185204
186- weights = torch .empty (
187- (M , topk ), device = router_logits .device , dtype = router_logits .dtype
188- )
205+ weights = torch .empty ((M , topk ), device = router_logits .device , dtype = torch .float32 )
189206 indices = torch .empty ((M , topk ), device = router_logits .device , dtype = torch .int32 )
190207
191208 BLOCK_N = triton .next_power_of_2 (N ) # num_padded_experts
192209 topk_padded = triton .next_power_of_2 (topk )
193- grid = (M ,)
194- num_stages = 2
195- ROWS_PER_PID = 4
210+
211+ # enable autotune to pick the right number of thread blocks;
212+ # refer to https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
213+ grid = lambda META : (triton .cdiv (M , META ["ROWS_PER_PID" ]),)
196214
197215 _topk_softmax_kernel [grid ](
198216 logits_ptr = router_logits ,
@@ -210,8 +228,6 @@ def fused_topk_softmax(
210228 stride_ik = indices .stride (1 ),
211229 BLOCK_N = BLOCK_N ,
212230 RENORM = renormalize ,
213- num_stages = num_stages ,
214- ROWS_PER_PID = ROWS_PER_PID ,
215231 )
216232
217233 return weights , indices
0 commit comments