@@ -47,48 +47,106 @@ def _topk_softmax_kernel(
     mask_n = offs_n < N
     store_mask = offs_k < topk
 
-    topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
-    topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
-
-    for row_idx in tl.range(pid, M, num_programs, num_stages):
-        logits = tl.load(
-            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
-            mask=mask_n,
-            other=float("-inf"),
-        )
-
-        if not RENORM:
-            row_sub_max = logits - tl.max(logits, axis=0)
-            numerator = tl.exp(row_sub_max)
-            denominator = tl.sum(numerator, axis=0)
-            logits = numerator / denominator
-
-        for k in tl.static_range(topk):
+    # Specialize on topk <= 2 and RENORM via tl.constexpr; the branches below
+    # are resolved at compile time, similar to `constexpr if` in C++17.
+    if topk == 1:
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
             cur_max = tl.max(logits, axis=0)
             cur_idx = tl.argmax(logits, axis=0)
 
-            k_mask = offs_k == k
-            topk_vals = tl.where(k_mask, cur_max, topk_vals)
-            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
-
-            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
-
-        if RENORM:
-            topk_vals = topk_vals - tl.max(topk_vals, axis=0)
-            numerator = tl.exp(topk_vals)
-            denominator = tl.sum(numerator, axis=0)
-            topk_vals = numerator / denominator
-
-        tl.store(
-            weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
-            topk_vals,
-            mask=store_mask,
-        )
-        tl.store(
-            indices_ptr + row_idx * stride_im + offs_k * stride_ik,
-            topk_idxs,
-            mask=store_mask,
-        )
+            if RENORM:
+                cur_max = 1
+
+            tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
+            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)
+
+    elif topk == 2:
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            val0 = tl.max(logits, axis=0)
+            idx0 = tl.argmax(logits, axis=0)
+            logits = tl.where(offs_n == idx0, float("-inf"), logits)
+            val1 = tl.max(logits, axis=0)
+            idx1 = tl.argmax(logits, axis=0)
+
+            if RENORM:
+                max_val = tl.maximum(val0, val1)
+                exp0 = tl.exp(val0 - max_val)
+                exp1 = tl.exp(val1 - max_val)
+                val0 = exp0 / (exp0 + exp1)
+                val1 = exp1 / (exp0 + exp1)
+
+            tl.store(weights_ptr + row_idx * stride_wm, val0)
+            tl.store(indices_ptr + row_idx * stride_im, idx0)
+            tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
+            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)
+
+    else:
+        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
+        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
+
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            for k in tl.static_range(topk):
+                cur_max = tl.max(logits, axis=0)
+                cur_idx = tl.argmax(logits, axis=0)
+
+                k_mask = offs_k == k
+                topk_vals = tl.where(k_mask, cur_max, topk_vals)
+                topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+                logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
+            if RENORM:
+                topk_vals = topk_vals - tl.max(topk_vals, axis=0)
+                numerator = tl.exp(topk_vals)
+                denominator = tl.sum(numerator, axis=0)
+                topk_vals = numerator / denominator
+
+            tl.store(
+                weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
+                topk_vals,
+                mask=store_mask,
+            )
+            tl.store(
+                indices_ptr + row_idx * stride_im + offs_k * stride_ik,
+                topk_idxs,
+                mask=store_mask,
+            )
 
 
 def fused_topk_softmax(
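
The comment in the diff likens `tl.constexpr` specialization to C++17 `constexpr if`: because `topk` and `RENORM` are compile-time constants of the Triton kernel, each distinct value produces its own compiled variant and untaken branches are removed, so the `topk == 1` and `topk == 2` fast paths add no overhead to the generic path. Below is a minimal, self-contained sketch of that mechanism, using a hypothetical `_scale_or_copy_kernel` that is not part of this PR:

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_or_copy_kernel(x_ptr, out_ptr, n, SCALE: tl.constexpr, BLOCK: tl.constexpr):
    # SCALE is tl.constexpr, so this `if` is resolved during compilation and
    # only the taken branch is emitted, much like `if constexpr` in C++17.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if SCALE == 1:
        tl.store(out_ptr + offs, x, mask=mask)  # plain copy, no multiply emitted
    else:
        tl.store(out_ptr + offs, x * SCALE, mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
# Each distinct SCALE value triggers a separate specialization/compile.
_scale_or_copy_kernel[grid](x, out, x.numel(), SCALE=2, BLOCK=256)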