Looking at the `fp8_mqa_logits` and `fp8_paged_mqa_logits` functions, it seems that after the q@k computation we don't apply causal masking before top-k?
I know the q/k fed to the mqa_logits kernel are different from the q/k fed to the MLA attention kernel, but we use the output of the mqa_logits kernel (after top-k 2048) as an index into the real MLA kvcache. So during the real attention computation we need to respect causality, both in the prefill and the decode (MTP) case.
The vLLM prefill dispatch via `torch.ops._C.top_k_per_row` doesn't seem to consider causality, while the decode dispatch does appear to apply a causal mask in the MTP case after the logits kernel, before top-k.
I'm not sure whether the framework side is supposed to apply causal masking before top-k, or whether causality actually doesn't matter inside the indexer kernel?
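To make the concern concrete, here is a small numpy sketch (not vLLM's actual code; function and variable names are mine) of what "causal mask before top-k" would look like for an indexer: future KV slots are set to -inf so they can never be selected, and only then are the top-k slot indices taken per query.

```python
import numpy as np

def causal_topk_indices(logits, q_positions, k, kv_positions=None):
    """Hypothetical indexer post-processing: mask acausal KV slots, then top-k.

    logits: [num_q, num_kv] similarity scores (conceptually q @ k.T).
    q_positions: absolute position of each query token.
    kv_positions: absolute position of each KV slot (defaults to 0..num_kv-1).
    """
    num_q, num_kv = logits.shape
    if kv_positions is None:
        kv_positions = np.arange(num_kv)
    masked = logits.copy()
    # A query at position p may only attend to KV slots at positions <= p,
    # so future slots get -inf and can never win the top-k selection.
    masked[kv_positions[None, :] > q_positions[:, None]] = -np.inf
    k = min(k, num_kv)
    # argpartition on the negated scores yields the k highest-scoring slots.
    idx = np.argpartition(-masked, k - 1, axis=1)[:, :k]
    return idx, masked

# Example: 3 queries at positions 0, 1, 2 over 4 KV slots.
logits = np.array([
    [0.9, 0.8, 0.7, 0.6],
    [0.1, 0.9, 0.8, 0.7],
    [0.1, 0.2, 0.9, 0.8],
])
idx, masked = causal_topk_indices(logits, np.arange(3), k=2)
```

Without this mask, a future slot with a high score (e.g. slot 3 for the query at position 2) could displace a valid past slot from the top-2048 set, which is why it matters whether the kernel or the framework applies it.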