@@ -22,11 +22,11 @@ def torch_dtype_to_tl(dtype: torch.dtype):

 @triton.jit
 def _topk_softmax_kernel(
-    logits_ptr: torch.Tensor,
-    weights_ptr: torch.Tensor,
-    indices_ptr: torch.Tensor,
+    logits_ptr,
+    weights_ptr,
+    indices_ptr,
     M,
-    N,
+    N: tl.constexpr,
     topk: tl.constexpr,
     topk_padded: tl.constexpr,
     stride_lm,
@@ -36,6 +36,7 @@ def _topk_softmax_kernel(
     stride_im,
     stride_ik,
     BLOCK_N: tl.constexpr,
+    RENORM: tl.constexpr,
     num_stages: tl.constexpr,
 ):
     pid = tl.program_id(0)
@@ -44,6 +45,7 @@ def _topk_softmax_kernel(
     offs_n = tl.arange(0, BLOCK_N)
     offs_k = tl.arange(0, topk_padded)
     mask_n = offs_n < N
+    store_mask = offs_k < topk

     topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
     topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
@@ -54,10 +56,12 @@ def _topk_softmax_kernel(
             mask=mask_n,
             other=float("-inf"),
         )
-        row_sub_max = logits - tl.max(logits, axis=0)
-        numerator = tl.exp(row_sub_max)
-        denominator = tl.sum(numerator, axis=0)
-        logits = numerator / denominator
+
+        if not RENORM:
+            row_sub_max = logits - tl.max(logits, axis=0)
+            numerator = tl.exp(row_sub_max)
+            denominator = tl.sum(numerator, axis=0)
+            logits = numerator / denominator

         for k in tl.static_range(topk):
             cur_max = tl.max(logits, axis=0)
@@ -69,7 +73,12 @@ def _topk_softmax_kernel(

             logits = tl.where(offs_n == cur_idx, float("-inf"), logits)

-        store_mask = offs_k < topk
+        if RENORM:
+            topk_vals = topk_vals - tl.max(topk_vals, axis=0)
+            numerator = tl.exp(topk_vals)
+            denominator = tl.sum(numerator, axis=0)
+            topk_vals = numerator / denominator
+
         tl.store(
             weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
             topk_vals,
@@ -82,66 +91,6 @@ def _topk_softmax_kernel(
     )


-@triton.jit
-def _topk_softmax_renorm_kernel(
-    logits_ptr,
-    weights_ptr,
-    indices_ptr,
-    M,
-    N,
-    topk: tl.constexpr,
-    topk_padded: tl.constexpr,
-    stride_lm,
-    stride_ln,
-    stride_wm,
-    stride_wk,
-    stride_im,
-    stride_ik,
-    BLOCK_N: tl.constexpr,
-    num_stages: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    num_programs = tl.num_programs(0)
-
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_k = tl.arange(0, topk_padded)
-    mask_n = offs_n < N
-
-    for row_idx in tl.range(pid, M, num_programs, num_stages):
-        logits = tl.load(
-            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
-            mask=mask_n,
-            other=float("-inf"),
-        )
-
-        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
-        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
-
-        running_max = float("-inf")
-        running_sum = 0.0
-
-        for k in tl.static_range(topk):
-            cur_max = tl.max(logits, axis=0)
-            cur_idx = tl.argmax(logits, axis=0)
-
-            new_max = tl.maximum(running_max, cur_max)
-            running_sum = running_sum * tl.exp(running_max - new_max) + tl.exp(
-                cur_max - new_max
-            )
-            running_max = new_max
-
-            k_mask = offs_k == k
-            topk_vals = tl.where(k_mask, cur_max, topk_vals)
-            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
-
-            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
-
-        topk_vals = tl.exp(topk_vals - running_max) / running_sum
-
-        tl.store(weights_ptr + row_idx * stride_wm + offs_k * stride_wk, topk_vals)
-        tl.store(indices_ptr + row_idx * stride_im + offs_k * stride_ik, topk_idxs)
-
-
 def fused_topk_softmax(
     router_logits: torch.Tensor,
     topk: int,
@@ -155,48 +104,28 @@ def fused_topk_softmax(
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)

     BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
-
     topk_padded = triton.next_power_of_2(topk)
-
     grid = (M,)
     num_stages = 2

-    if renormalize:
-        _topk_softmax_renorm_kernel[grid](
-            logits_ptr=router_logits,
-            weights_ptr=weights,
-            indices_ptr=indices,
-            M=M,
-            N=N,
-            topk=topk,
-            topk_padded=topk_padded,
-            stride_lm=router_logits.stride(0),
-            stride_ln=router_logits.stride(1),
-            stride_wm=weights.stride(0),
-            stride_wk=weights.stride(1),
-            stride_im=indices.stride(0),
-            stride_ik=indices.stride(1),
-            BLOCK_N=BLOCK_N,
-            num_stages=num_stages,
-        )
-    else:
-        _topk_softmax_kernel[grid](
-            logits_ptr=router_logits,
-            weights_ptr=weights,
-            indices_ptr=indices,
-            M=M,
-            N=N,
-            topk=topk,
-            topk_padded=topk_padded,
-            stride_lm=router_logits.stride(0),
-            stride_ln=router_logits.stride(1),
-            stride_wm=weights.stride(0),
-            stride_wk=weights.stride(1),
-            stride_im=indices.stride(0),
-            stride_ik=indices.stride(1),
-            BLOCK_N=BLOCK_N,
-            num_stages=num_stages,
-        )
+    _topk_softmax_kernel[grid](
+        logits_ptr=router_logits,
+        weights_ptr=weights,
+        indices_ptr=indices,
+        M=M,
+        N=N,
+        topk=topk,
+        topk_padded=topk_padded,
+        stride_lm=router_logits.stride(0),
+        stride_ln=router_logits.stride(1),
+        stride_wm=weights.stride(0),
+        stride_wk=weights.stride(1),
+        stride_im=indices.stride(0),
+        stride_ik=indices.stride(1),
+        BLOCK_N=BLOCK_N,
+        RENORM=renormalize,
+        num_stages=num_stages,
+    )

     return weights, indices

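For reference, both modes of the merged kernel correspond to a plain eager-mode formulation. With `renormalize=True`, renormalizing the top-k softmax weights over all `N` experts reduces algebraically to a softmax over just the top-k logits (both equal `exp(l_i - m) / sum_j exp(l_j - m)` over the selected logits), which is why the dedicated renorm kernel with its running max/sum could be folded into the `RENORM` branch. The sketch below is a hypothetical reference helper for checking the fused kernel, not part of this PR; `topk_softmax_reference` is an assumed name.

```python
import torch

def topk_softmax_reference(router_logits: torch.Tensor, topk: int, renormalize: bool):
    """Hypothetical eager-mode reference for fused_topk_softmax."""
    if renormalize:
        # Top-k of the raw logits, then softmax over only the selected ones.
        # Algebraically identical to softmax over all experts followed by
        # renormalization of the top-k weights.
        vals, idxs = torch.topk(router_logits, topk, dim=-1)
        weights = torch.softmax(vals.float(), dim=-1)
    else:
        # Softmax over all experts, then top-k of the probabilities.
        probs = torch.softmax(router_logits.float(), dim=-1)
        weights, idxs = torch.topk(probs, topk, dim=-1)
    return weights, idxs.to(torch.int32)
```

Modulo tie-breaking order when a row contains duplicate logits, `torch.testing.assert_close` on the weights from `fused_topk_softmax` and this reference should pass for both values of `renormalize`.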