File tree Expand file tree Collapse file tree 1 file changed +10
-9
lines changed
Expand file tree Collapse file tree 1 file changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -246,10 +246,10 @@ __device__ void _moe_align_block_size_small_batch_expert(
246246 int32_t rank_post_pad =
247247 tokens_cnts[threadIdx .x * num_experts + expert_id] + cumsum[expert_id];
248248
249- int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
250- sorted_token_ids[sorted_token_ids_offset + rank_post_pad] +=
251- ((i - numel) * mask) ;
252- tokens_cnts[ threadIdx . x * num_experts + expert_id] += mask;
249+ if ( token_mask == nullptr || token_mask[i / topk_num]) {
250+ sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;
251+ ++tokens_cnts[ threadIdx . x * num_experts + expert_id] ;
252+ }
253253 }
254254}
255255
@@ -268,11 +268,12 @@ __device__ void _count_and_sort_expert_tokens(
268268 continue ;
269269 }
270270
271- int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
272- int32_t rank_post_pad = atomicAdd (
273- &cumsum_buffer[(model_offset * (num_experts + 1 )) + expert_id], mask);
274- sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] +=
275- ((i - numel) * mask);
271+ if (token_mask == nullptr || token_mask[i / topk_num]) {
272+ int32_t rank_post_pad = atomicAdd (
273+ &cumsum_buffer[(model_offset * (num_experts + 1 )) + expert_id], 1 );
274+ sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] =
275+ i;
276+ }
276277 }
277278}
278279
You can’t perform that action at this time.
0 commit comments