
Commit 7dcee15

fix lora masking
Signed-off-by: gnovack <[email protected]>
1 parent 331cf23

1 file changed: 10 additions, 9 deletions

csrc/moe/moe_align_sum_kernels.cu

@@ -246,10 +246,10 @@ __device__ void _moe_align_block_size_small_batch_expert(
     int32_t rank_post_pad =
         tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
 
-    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
-    sorted_token_ids[sorted_token_ids_offset + rank_post_pad] +=
-        ((i - numel) * mask);
-    tokens_cnts[threadIdx.x * num_experts + expert_id] += mask;
+    if (token_mask == nullptr || token_mask[i / topk_num]) {
+      sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;
+      ++tokens_cnts[threadIdx.x * num_experts + expert_id];
+    }
   }
 }
 
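In this small-batch path, the old code masked arithmetically: it scaled the counter increment by a 0/1 mask and added (i - numel) * mask to a sorted_token_ids slot pre-filled with the numel sentinel, so a masked token still computed a destination slot and issued a read-modify-write. The fix branches instead, so masked tokens neither write nor advance the per-thread count. Below is a minimal host-side C++ sketch of the fixed bucketing logic; it is a simplification (single model, no block-size padding of the expert buckets), and the harness and sample data are illustrative, with only topk_ids, token_mask, numel, and sorted_token_ids echoing the kernel's names.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_experts = 4;
  const int topk = 2;  // the kernel calls this topk_num
  // Flattened [num_tokens, topk] expert assignments; numel = num_tokens * topk.
  std::vector<int32_t> topk_ids = {0, 2, 1, 1, 3, 0};
  // Per-token LoRA mask: token 1 is inactive for this adapter.
  std::vector<int32_t> token_mask = {1, 0, 1};
  const int32_t numel = static_cast<int32_t>(topk_ids.size());

  // Count tokens per expert, skipping masked tokens (the fixed behavior).
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t i = 0; i < numel; ++i) {
    if (token_mask[i / topk]) ++counts[topk_ids[i]];
  }

  // Exclusive prefix sum gives each expert's start offset.
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e) cumsum[e + 1] = cumsum[e] + counts[e];

  // Scatter token indices into expert-sorted order. Untouched slots keep the
  // numel sentinel, matching the kernel's pre-filled padding value.
  std::vector<int32_t> sorted_token_ids(numel, numel);
  std::vector<int32_t> rank(cumsum.begin(), cumsum.end() - 1);
  for (int32_t i = 0; i < numel; ++i) {
    // Branch instead of arithmetic mask: a masked token writes nothing and
    // does not advance its expert's rank counter.
    if (token_mask[i / topk]) {
      sorted_token_ids[rank[topk_ids[i]]++] = i;
    }
  }

  for (int32_t v : sorted_token_ids) std::printf("%d ", v);
  std::printf("\n");  // prints: 0 5 1 4 6 6
  return 0;
}

Compiled with g++ -std=c++17, the sketch prints 0 5 1 4 6 6: the masked token (flattened indices 2 and 3) claims no rank in any expert's bucket, and the unclaimed slots keep the numel sentinel.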

@@ -268,11 +268,12 @@ __device__ void _count_and_sort_expert_tokens(
       continue;
     }
 
-    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
-    int32_t rank_post_pad = atomicAdd(
-        &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], mask);
-    sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] +=
-        ((i - numel) * mask);
+    if (token_mask == nullptr || token_mask[i / topk_num]) {
+      int32_t rank_post_pad = atomicAdd(
+          &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1);
+      sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] =
+          i;
+    }
   }
 }
 
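The second hunk applies the same branch-based fix to the multi-threaded counting path, where the distinction matters more. The old form still issued an atomicAdd of mask (an add of 0 for a masked token reserves nothing but returns the current counter as rank_post_pad) and then performed a non-atomic += on the shared sorted_token_ids buffer, so a masked token's no-op read-modify-write could overlap with a concurrent unmasked write to the same slot. The following is a hypothetical CUDA kernel isolating the fixed behavior, stripped of the real kernel's model_offset bookkeeping and grid-stride loop; expert_ids and the launch shape are illustrative.

#include <cstdint>

// Hypothetical kernel isolating the fixed atomic path; parameter names mirror
// the kernel's, but the signature and setup are simplified for illustration.
__global__ void scatter_unmasked_tokens(const int32_t* __restrict__ expert_ids,
                                        const int32_t* __restrict__ token_mask,
                                        int32_t* cumsum_buffer,
                                        int32_t* sorted_token_ids,
                                        int32_t numel, int32_t topk_num) {
  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= numel) return;
  // Branch instead of arithmetic masking: a masked token neither reserves a
  // rank (no atomicAdd) nor touches sorted_token_ids.
  if (token_mask == nullptr || token_mask[i / topk_num]) {
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_ids[i]], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

Masked tokens now skip both the atomicAdd and the store, so every reserved rank is written exactly once.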
