Commit 13ccd2f

ch-wanshifangx authored and committed
Fix tensor descriptions in buffer.py
1 parent b6ceecd commit 13ccd2f

File tree

1 file changed (+2, −2 lines)

deep_ep/buffer.py

Lines changed: 2 additions & 2 deletions
@@ -684,9 +684,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
     `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
     Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
     with `use_nvfp4=True`: the first element is a `torch.Tensor` shaped as
-    `[num_local_experts, hidden // 2, num_max_dispatch_tokens_per_rank * num_ranks]` with `torch.uint8`.
+    `[num_max_dispatch_tokens_per_rank * num_ranks, hidden // 2, num_local_experts]` with `torch.uint8`.
     The second tensor is the corresponding scales for the first element with shape
-    `[32, 4, num_max_dispatch_tokens_per_rank * num_ranks // 128, 4, hidden // 64, num_local_experts]` with `torch.uint8`.
+    `[32, 4, num_max_dispatch_tokens_per_rank * num_ranks // 128, 4, hidden // 64, num_local_experts]` with `torch.float8_e4m3fn`.
     With `use_fp8=False and use_nvfp4=False`, the result would be a tensor shaped as
     `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
     Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
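For reference, here is a minimal sketch (not DeepEP code) that allocates dummy tensors matching the shapes and dtypes the corrected docstring describes for each mode. All sizes below are illustrative placeholders, and the two-FP4-values-per-byte packing (hence `hidden // 2`) is an assumption inferred from the `torch.uint8` payload dtype:

```python
import torch

# Illustrative placeholder sizes (not from the commit).
num_local_experts = 4
num_max_dispatch_tokens_per_rank = 128
num_ranks = 2
hidden = 1024
num_tokens = num_max_dispatch_tokens_per_rank * num_ranks  # 256

# use_nvfp4=True: packed payload, assumed two 4-bit values per uint8 byte,
# shaped per the corrected description.
nvfp4_packed = torch.empty(num_tokens, hidden // 2, num_local_experts,
                           dtype=torch.uint8)

# Corresponding scaling factors, now described as torch.float8_e4m3fn.
nvfp4_scales = torch.empty(32, 4, num_tokens // 128, 4, hidden // 64,
                           num_local_experts, dtype=torch.float8_e4m3fn)

# use_fp8=False and use_nvfp4=False: plain BF16 output.
bf16_out = torch.empty(num_local_experts, num_tokens, hidden,
                       dtype=torch.bfloat16)

print(nvfp4_packed.shape)   # torch.Size([256, 512, 4])
print(nvfp4_scales.dtype)   # torch.float8_e4m3fn
print(bf16_out.shape)       # torch.Size([4, 256, 1024])
```

Note that these dummy allocations only mirror the documented shapes; the actual buffers returned by `low_latency_dispatch` are produced internally, and per the docstring the last two dimensions of the scaling tensors are laid out column-major for TMA compatibility.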
