 from vllm.triton_utils import tl, triton


+def torch_dtype_to_tl(dtype: torch.dtype):
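+    """Map a torch.dtype to the corresponding Triton scalar dtype."""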
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
 @triton.jit
 def _topk_softmax_kernel(
-    logits_ptr,
-    weights_ptr,
-    indices_ptr,
+    logits_ptr: torch.Tensor,
+    weights_ptr: torch.Tensor,
+    indices_ptr: torch.Tensor,
     M,
     N,
     topk: tl.constexpr,
@@ -23,81 +36,167 @@ def _topk_softmax_kernel(
     stride_im,
     stride_ik,
     BLOCK_N: tl.constexpr,
-    RENORM: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
-    token_idx = tl.program_id(0)
+    pid = tl.program_id(0)
+    num_programs = tl.num_programs(0)

-    offs = tl.arange(0, BLOCK_N)
-    mask = offs < N
-    logit_offs = logits_ptr + token_idx * stride_lm + offs * stride_ln
-    logits = tl.load(logit_offs, mask=mask, other=float("-inf"))
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, topk_padded)
+    mask_n = offs_n < N

     topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
     topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)

-    working_logits = logits
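+    # Persistent (grid-stride) loop: program `pid` handles rows pid, pid + num_programs, ...
+    # The num_stages argument lets Triton software-pipeline consecutive row iterations.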
+    for row_idx in tl.range(pid, M, num_programs, num_stages):
+        logits = tl.load(
+            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+            mask=mask_n,
+            other=float("-inf"),
+        )
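+        # Numerically stable softmax over the row; padded lanes were loaded as
+        # -inf, so they contribute exp(-inf) = 0 to the denominator.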
+        row_sub_max = logits - tl.max(logits, axis=0)
+        numerator = tl.exp(row_sub_max)
+        denominator = tl.sum(numerator, axis=0)
+        logits = numerator / denominator
+
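+        # Iterative top-k selection: record the current argmax, then knock it
+        # out with -inf so the next iteration finds the runner-up.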
+        for k in tl.static_range(topk):
+            cur_max = tl.max(logits, axis=0)
+            cur_idx = tl.argmax(logits, axis=0)
+
+            k_mask = offs_k == k
+            topk_vals = tl.where(k_mask, cur_max, topk_vals)
+            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
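+        # Only the first topk of the topk_padded lanes are real output columns.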
+        store_mask = offs_k < topk
+        tl.store(
+            weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
+            topk_vals,
+            mask=store_mask,
+        )
+        tl.store(
+            indices_ptr + row_idx * stride_im + offs_k * stride_ik,
+            topk_idxs,
+            mask=store_mask,
+        )

-    for k in range(topk):
-        cur_max = tl.max(working_logits, axis=0)
-        cur_idx = tl.argmax(working_logits, axis=0)

-        k_mask = tl.arange(0, topk_padded) == k
-        topk_vals = tl.where(k_mask, cur_max, topk_vals)
-        topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+@triton.jit
+def _topk_softmax_renorm_kernel(
+    logits_ptr,
+    weights_ptr,
+    indices_ptr,
+    M,
+    N,
+    topk: tl.constexpr,
+    topk_padded: tl.constexpr,
+    stride_lm,
+    stride_ln,
+    stride_wm,
+    stride_wk,
+    stride_im,
+    stride_ik,
+    BLOCK_N: tl.constexpr,
+    num_stages: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_programs = tl.num_programs(0)

-        mask_selected = offs == cur_idx
-        working_logits = tl.where(mask_selected, float("-inf"), working_logits)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, topk_padded)
+    mask_n = offs_n < N

-    if RENORM:
-        max_val = tl.max(topk_vals, axis=0)
-        exp_vals = tl.exp(topk_vals - max_val)
-        sum_exp = tl.sum(exp_vals, axis=0)
-        topk_vals = exp_vals / sum_exp
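+    # Same persistent row loop as _topk_softmax_kernel above.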
+    for row_idx in tl.range(pid, M, num_programs, num_stages):
+        logits = tl.load(
+            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+            mask=mask_n,
+            other=float("-inf"),
+        )

-    offs_k = tl.arange(0, topk_padded)
+        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
+        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)

-    store_mask = offs_k < topk
+        running_max = float("-inf")
+        running_sum = 0.0

-    weight_ptrs = weights_ptr + token_idx * stride_wm + offs_k * stride_wk
-    tl.store(weight_ptrs, topk_vals, mask=store_mask)
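+        # No upfront full-row softmax here: softmax is monotonic, so argmax over
+        # the raw logits selects the same experts; only the winners are
+        # normalized at the end of the loop.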
+        for k in tl.static_range(topk):
+            cur_max = tl.max(logits, axis=0)
+            cur_idx = tl.argmax(logits, axis=0)

-    index_ptrs = indices_ptr + token_idx * stride_im + offs_k * stride_ik
-    tl.store(index_ptrs, topk_idxs, mask=store_mask)
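+            # Online-softmax update of the running max/sum over the selected
+            # logits, rescaling the sum whenever a new maximum appears.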
+            new_max = tl.maximum(running_max, cur_max)
+            running_sum = running_sum * tl.exp(running_max - new_max) + tl.exp(
+                cur_max - new_max
+            )
+            running_max = new_max
+
+            k_mask = offs_k == k
+            topk_vals = tl.where(k_mask, cur_max, topk_vals)
+            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
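+        # Softmax restricted to the top-k logits, which equals softmax over all
+        # experts followed by renormalization of the kept top-k weights.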
+        topk_vals = tl.exp(topk_vals - running_max) / running_sum
+
+        # Mask to the first topk lanes: the outputs are only (M, topk) wide, and
+        # topk_padded exceeds topk whenever topk is not a power of two.
+        store_mask = offs_k < topk
+        tl.store(weights_ptr + row_idx * stride_wm + offs_k * stride_wk, topk_vals, mask=store_mask)
+        tl.store(indices_ptr + row_idx * stride_im + offs_k * stride_ik, topk_idxs, mask=store_mask)


 def fused_topk_softmax(
     router_logits: torch.Tensor,
     topk: int,
     renormalize: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    M, N = router_logits.shape
+    M, N = router_logits.shape  # num_tokens, num_experts

-    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
+    weights = torch.empty(
+        (M, topk), device=router_logits.device, dtype=router_logits.dtype
+    )
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)

-    BLOCK_N = triton.next_power_of_2(N)
+    BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts

     topk_padded = triton.next_power_of_2(topk)

     grid = (M,)
-
-    _topk_softmax_kernel[grid](
-        logits_ptr=router_logits,
-        weights_ptr=weights,
-        indices_ptr=indices,
-        M=M,
-        N=N,
-        topk=topk,
-        topk_padded=topk_padded,
-        stride_lm=router_logits.stride(0),
-        stride_ln=router_logits.stride(1),
-        stride_wm=weights.stride(0),
-        stride_wk=weights.stride(1),
-        stride_im=indices.stride(0),
-        stride_ik=indices.stride(1),
-        BLOCK_N=BLOCK_N,
-        RENORM=renormalize,
-    )
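+    # Software-pipelining depth for the per-row loop inside both kernels.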
+    num_stages = 2
+
+    if renormalize:
+        _topk_softmax_renorm_kernel[grid](
+            logits_ptr=router_logits,
+            weights_ptr=weights,
+            indices_ptr=indices,
+            M=M,
+            N=N,
+            topk=topk,
+            topk_padded=topk_padded,
+            stride_lm=router_logits.stride(0),
+            stride_ln=router_logits.stride(1),
+            stride_wm=weights.stride(0),
+            stride_wk=weights.stride(1),
+            stride_im=indices.stride(0),
+            stride_ik=indices.stride(1),
+            BLOCK_N=BLOCK_N,
+            num_stages=num_stages,
+        )
+    else:
+        _topk_softmax_kernel[grid](
+            logits_ptr=router_logits,
+            weights_ptr=weights,
+            indices_ptr=indices,
+            M=M,
+            N=N,
+            topk=topk,
+            topk_padded=topk_padded,
+            stride_lm=router_logits.stride(0),
+            stride_ln=router_logits.stride(1),
+            stride_wm=weights.stride(0),
+            stride_wk=weights.stride(1),
+            stride_im=indices.stride(0),
+            stride_ik=indices.stride(1),
+            BLOCK_N=BLOCK_N,
+            num_stages=num_stages,
+        )

     return weights, indices
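
As a quick sanity check, the fused path can be compared against an eager PyTorch reference. This is a minimal sketch: the import path below is illustrative (the diff does not show where the module lives), and the tolerances are arbitrary.

```python
import torch

# Hypothetical import path; point this at wherever fused_topk_softmax lands.
from fused_topk_softmax_module import fused_topk_softmax

# 16 tokens routed over 64 experts, keeping the top 4.
router_logits = torch.randn(16, 64, device="cuda", dtype=torch.float32)
weights, indices = fused_topk_softmax(router_logits, topk=4, renormalize=True)

# Eager reference: softmax, top-k, then renormalize the kept weights.
probs = torch.softmax(router_logits, dim=-1)
ref_w, _ = torch.topk(probs, k=4, dim=-1)
ref_w = ref_w / ref_w.sum(dim=-1, keepdim=True)
torch.testing.assert_close(weights, ref_w, atol=1e-3, rtol=1e-3)
```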