diff --git a/csrc/config.hpp b/csrc/config.hpp
index 0e4f5b06..b56d200c 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -133,8 +133,11 @@ struct LowLatencyLayout {
         return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
     }
 
-    LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
+    LowLatencyLayout(
+        bool disable_ll_layered, void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
         const int num_scales = hidden / 128;
+        const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS;  // TODO: derive NUM_MAX_NVL_PEERS automatically
+                                                              // from the runtime process topology
 
         // Dispatch and combine layout:
         // - 2 symmetric odd/even send buffer
@@ -145,7 +148,12 @@ struct LowLatencyLayout {
         // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
         // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
+        size_t per_meta_data_size = sizeof(int4);
+        size_t per_token_size = std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
         size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
+        if (!disable_ll_layered) {
+            num_bytes_per_dispatch_msg = per_meta_data_size + per_token_size;
+        }
         size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
 
         // Send buffer
@@ -158,13 +166,23 @@ struct LowLatencyLayout {
         // Symmetric receive buffers
         // TODO: optimize memory usages
         size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+        if (!disable_ll_layered) {
+            dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * per_meta_data_size +
+                num_nodes * num_max_dispatch_tokens_per_rank * per_token_size;  // NOTES: `num_experts == num_local_experts * num_ranks`
+        }
         size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
         size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
         EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
         total_bytes += recv_buffer_bytes * 2;
 
         // Symmetric signaling buffers
-        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
+        size_t dispatch_recv_count_buffer_bytes =
+            num_experts * sizeof(int);  // NOTES: `num_experts == num_local_experts * num_ranks == num_local_experts * NUM_MAX_NVL_PEERS * num_nodes`;
+                                        // half of this region is used by dispatch, the other half by combine
+        if (!disable_ll_layered) {
+            dispatch_recv_count_buffer_bytes +=
+                NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
+        }
         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
         size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
         size_t signaling_buffer_bytes_aligned = align_up(signaling_buffer_bytes, 128);
@@ -187,8 +205,10 @@ struct LowLatencyLayout {
     }
 };
 
-size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
-    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
+size_t get_low_latency_rdma_size_hint(
+    bool disable_ll_layered, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
+    auto num_bytes =
+        LowLatencyLayout(disable_ll_layered, nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
     return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
 }
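For intuition, here is a standalone host-side sketch (not part of the patch; all values are illustrative) that evaluates both receive-buffer sizing formulas above for a 2-node, 16-rank deployment:

```cpp
// Standalone sketch: compares the dispatch receive-buffer size in the original
// and layered layouts. All values below are illustrative assumptions.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const int hidden = 7168, num_ranks = 16, nvl_peers = 8;  // 2 nodes x 8 GPUs
    const int num_nodes = num_ranks / nvl_peers;
    const int num_experts = 16 * num_ranks;  // 16 local experts per rank
    const int T = 128;                       // num_max_dispatch_tokens_per_rank
    const int num_scales = hidden / 128;

    const size_t meta = 16;  // sizeof(int4): source index + 3 reserved ints
    const size_t data = std::max(hidden * sizeof(short) /* nv_bfloat16 */,
                                 hidden + num_scales * sizeof(float));  // FP8 payload + scales

    // Original: every (expert, source rank) slot holds meta + payload.
    const size_t dense = (size_t)num_experts * T * (meta + data);
    // Layered: per-slot meta only, plus one payload copy per source node.
    const size_t layered = (size_t)num_experts * T * meta + (size_t)num_nodes * T * data;
    printf("dense: %zu MiB, layered: %zu MiB\n", dense >> 20, layered >> 20);
    return 0;
}
```

The payload term shrinks from `num_experts * T` slots to `num_nodes * T` slots because the layered path stores a token's payload once per source node; only the 16-byte meta slot stays per-(expert, source rank).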
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index ab305952..714774c8 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -132,12 +132,14 @@ Buffer::Buffer(int rank,
                bool low_latency_mode,
                bool explicitly_destroy,
                bool enable_shrink,
-               bool use_fabric)
+               bool use_fabric,
+               bool disable_ll_layered)
     : rank(rank),
      num_ranks(num_ranks),
      num_nvl_bytes(num_nvl_bytes),
      num_rdma_bytes(num_rdma_bytes),
      enable_shrink(enable_shrink),
+      _disable_ll_layered(disable_ll_layered),
      low_latency_mode(low_latency_mode),
      explicitly_destroy(explicitly_destroy),
      comm_stream(at::cuda::getStreamFromPool(true)),
@@ -1499,7 +1501,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
 
-    auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    auto layout = LowLatencyLayout(_disable_ll_layered, rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
     auto clean_meta_0 = layout.buffers[0].clean_meta();
     auto clean_meta_1 = layout.buffers[1].clean_meta();
 
@@ -1571,7 +1573,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
     auto num_local_experts = num_experts / num_ranks;
 
     // Buffer control
-    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    LowLatencyLayout layout(_disable_ll_layered, rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
     EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
     auto buffer = layout.buffers[low_latency_buffer_idx];
     auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
@@ -1616,6 +1618,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
         internode_ll::dispatch(
+            _disable_ll_layered,
             packed_recv_x.data_ptr(),
             packed_recv_x_scales_ptr,
             packed_recv_src_info.data_ptr<int64_t>(),
@@ -1729,7 +1732,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
 
     // Buffer control
-    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    LowLatencyLayout layout(_disable_ll_layered, rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
     EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
     auto buffer = layout.buffers[low_latency_buffer_idx];
     auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
@@ -1756,7 +1759,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
         internode_ll::combine(
+            _disable_ll_layered,
             combined_x.data_ptr(),
@@ … @@ PYBIND11_MODULE(deep_ep_cpp, m)
     pybind11::class_<deep_ep::Buffer>(m, "Buffer")
-        .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>())
+        .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool, bool>())
         .def("is_available", &deep_ep::Buffer::is_available)
         .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
         .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
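The `pybind11::init<>` parameter pack must mirror the C++ constructor's parameter types in declaration order, which is why the binding gains a trailing `bool` alongside the constructor. A minimal self-contained sketch of the pattern (class, member, and module names hypothetical):

```cpp
#include <pybind11/pybind11.h>

// Hypothetical class standing in for deep_ep::Buffer.
struct Widget {
    Widget(int rank, int num_ranks, bool use_fabric, bool disable_ll_layered)
        : rank(rank), num_ranks(num_ranks), use_fabric(use_fabric), disable_ll_layered(disable_ll_layered) {}
    int rank, num_ranks;
    bool use_fabric, disable_ll_layered;
};

PYBIND11_MODULE(widget_cpp, m) {
    // The init<> pack selects the constructor overload; it must list the
    // parameter types one-to-one, including any newly appended trailing bool.
    pybind11::class_<Widget>(m, "Widget")
        .def(pybind11::init<int, int, bool, bool>())
        .def_readonly("disable_ll_layered", &Widget::disable_ll_layered);
}
```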
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 604f3d9c..090e5a4f 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -70,6 +70,7 @@ struct Buffer {
 
     // Shrink mode buffer
     bool enable_shrink = false;
+    bool _disable_ll_layered = false;
     int* mask_buffer_ptr = nullptr;
     int* sync_buffer_ptr = nullptr;
 
@@ -120,7 +121,8 @@ struct Buffer {
            bool low_latency_mode,
            bool explicitly_destroy,
            bool enable_shrink,
-           bool use_fabric);
+           bool use_fabric,
+           bool disable_ll_layered);
 
     ~Buffer() noexcept(false);
 
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 95639e8e..c43dd5ec 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -282,7 +282,8 @@ void clean_low_latency_buffer(int* clean_0,
                               int* sync_buffer,
                               cudaStream_t stream);
 
-void dispatch(void* packed_recv_x,
+void dispatch(bool disable_ll_layered,
+              void* packed_recv_x,
               void* packed_recv_x_scales,
               int64_t* packed_recv_src_info,
               int64_t* packed_recv_layout_range,
@@ -312,7 +313,8 @@ void dispatch(void* packed_recv_x,
               cudaStream_t stream,
               int phases);
 
-void combine(void* combined_x,
+void combine(bool disable_ll_layered,
+             void* combined_x,
              void* rdma_recv_x,
              int* rdma_recv_flag,
              void* rdma_send_x,
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 9215b1cc..86dfaad3 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -1,7 +1,10 @@
+#include
+
 #include "configs.cuh"
 #include "exception.cuh"
 #include "ibgda_device.cuh"
 #include "launch.cuh"
+#include "utils.cuh"
 
 namespace deep_ep {
 
@@ -127,7 +130,8 @@ void clean_low_latency_buffer(int* clean_0,
 }
 
 template <bool kUseFP8, bool kUseUE8M0, int kHidden>
-__global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
+__global__ __launch_bounds__(1024, 1) void dispatch(bool disable_ll_layered,
+                                                    void* packed_recv_x,
                                                     void* packed_recv_x_scales,
                                                     int64_t* packed_recv_src_info,
                                                     int64_t* packed_recv_layout_range,
@@ -164,6 +168,20 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
     const auto sub_warp_id = warp_id % num_warps_per_group;
     const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
+    const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
+    const auto num_nodes = num_ranks / num_nvl_ranks;
+    int* data_ready_counter = reinterpret_cast<int*>(rdma_recv_count + num_experts);
+    int* next_clean_data_ready_counter = reinterpret_cast<int*>(next_clean + num_experts);
+    auto* data_ready_send_buffer =
+        data_ready_counter + num_nodes * num_max_dispatch_tokens_per_rank * num_nvl_ranks;
+    if (!disable_ll_layered) {
+        if (thread_id < num_nvl_ranks) {
+            st_na_global(data_ready_send_buffer + thread_id, 2);  // pre-fill the send source with the ready value (2)
+        }
+        __syncthreads();
+        EP_DEVICE_ASSERT(num_ranks % num_nvl_ranks == 0);
+    }
+
     // May extract UE8M0 from the scales
     using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
     using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
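The pointers above carve the enlarged signaling region (sized in `config.hpp` earlier in this patch) out of the space behind `rdma_recv_count`. A sketch of the same carve-up as a host-style helper (struct and function names hypothetical):

```cpp
// Hypothetical helper mirroring the kernel's pointer carve-up of the signaling
// buffer. T is num_max_dispatch_tokens_per_rank; nvl is NUM_MAX_NVL_PEERS.
struct SignalingView {
    int* recv_count;              // [num_experts]: per-expert token counts (as before)
    int* data_ready_counter;      // [num_nodes * T * nvl]: one flag per (token, NVL peer)
    int* data_ready_send_buffer;  // [nvl]: pre-filled with 2s, used as the RDMA source
};

inline SignalingView carve_signaling(int* base, int num_experts, int num_nodes, int T, int nvl) {
    SignalingView v;
    v.recv_count = base;
    v.data_ready_counter = base + num_experts;
    v.data_ready_send_buffer = v.data_ready_counter + num_nodes * T * nvl;
    return v;  // total ints: num_experts + num_nodes * T * nvl + nvl, matching the
               // extra bytes added to dispatch_recv_count_buffer_bytes in config.hpp
}
```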
@@ -177,11 +195,20 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
 
     // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
     // NOTES: currently we have 3 reserved int fields for future use
+    // Non-layered (original) message layout {
     using vec_t = std::conditional_t<kUseFP8, int2, int4>;
     const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
-
+    // } layered message layout: a fixed-size meta slot plus a separate per-token payload slot {
+    const size_t num_bytes_per_meta = sizeof(int4);
+    const size_t num_bytes_per_data = (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    const size_t num_bytes_per_msg_new = num_bytes_per_meta + num_bytes_per_data;
+    EP_DEVICE_ASSERT(num_bytes_per_msg_new % sizeof(int4) == 0);
+
+    void* rdma_recv_x_meta = rdma_recv_x;
+    void* rdma_recv_x_data = (void*)(uint64_t(rdma_recv_x) + num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_meta);
+    // }
 
     // Expert counts
     constexpr int kNumMaxWarpGroups = 32;
     __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
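So meta slots keep the original per-(expert, source rank, slot) indexing, while payloads are stored once per source node and addressed by the sender's `token_idx`. The two address computations used by the send and receive paths below, extracted into hypothetical helpers:

```cpp
#include <cstdint>

// Hypothetical helpers mirroring the kernel's address arithmetic.
// T is num_max_dispatch_tokens_per_rank; the meta slot is one int4 (16 bytes).
inline uint64_t meta_slot_addr(uint64_t recv_meta_base, int local_expert_idx, int src_rank,
                               int slot_idx, int num_ranks, int T) {
    const uint64_t B = 16;  // num_bytes_per_meta == sizeof(int4)
    return recv_meta_base + (uint64_t(local_expert_idx) * num_ranks * T + uint64_t(src_rank) * T + slot_idx) * B;
}

inline uint64_t data_slot_addr(uint64_t recv_data_base, int src_node_id, int token_idx,
                               uint64_t num_bytes_per_data, int T) {
    // One payload copy per (source node, source token); experts on this rank share it.
    return recv_data_base + (uint64_t(src_node_id) * T + token_idx) * num_bytes_per_data;
}
```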
@@ -202,7 +229,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
 
     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
         const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-        const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+        auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+        if (!disable_ll_layered) {
+            rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg_new);
+        }
         const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
         const auto rdma_x_scales = reinterpret_cast<scale_t*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
@@ -253,29 +283,112 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
 
         // Issue IBGDA sends
         if (dst_expert_idx >= 0) {
+            int send_node_id = dst_expert_idx >= 0 ? dst_expert_idx / num_local_experts / num_nvl_ranks : -1;
             int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
             slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
             const auto dst_rank = dst_expert_idx / num_local_experts;
             const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
-            const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
-            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
-                dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-                rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;
-            const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
-            if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
-                if (dst_p2p_ptr == 0) {
-                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
-                } else {
-                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
-                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
-                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
-                    UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+            auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks +
+                rank % num_nvl_ranks;  // always target the peer sharing this rank's local GPU id (same-rail RDMA traffic)
+            auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
+            if (!disable_ll_layered) {
+                if (not is_rank_masked(mask_buffer_ptr, real_write_dst_rank)) {
+                    {  // avoid sending the same token to the same node repeatedly
+                        EP_DEVICE_ASSERT(num_topk <= 32);
+                        auto tmp_dst_expert_id =
+                            lane_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id)) : -1;
+                        auto tmp_dst_node_id = tmp_dst_expert_id >= 0 ? tmp_dst_expert_id / num_local_experts / num_nvl_ranks : -1;
+                        #pragma unroll
+                        for (int i = 0; i < warp_id; ++i) {
+                            auto dst_node_id = __shfl_sync(0xffffffff, tmp_dst_node_id, i);  // broadcast top-k entry i's node
+                            if (dst_node_id == send_node_id) {  // an earlier top-k entry already targets this node
+                                send_node_id = -1;
+                                break;
+                            }
+                        }
+                    }
+
+                    if (send_node_id != -1) {  // send the token payload (at most once per destination node)
+                        const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx) + num_bytes_per_meta;
+                        const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
+                            (rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
+                            token_idx * num_bytes_per_data;
+                        const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, real_write_dst_rank);
+                        if (dst_p2p_ptr == 0) {
+                            nvshmemi_ibgda_put_nbi_warp(
+                                dst_ptr, src_ptr, num_bytes_per_data, real_write_dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+                        } else {
+                            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
+                            const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                            const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+                            UNROLLED_WARP_COPY(
+                                7, lane_id, num_bytes_per_data / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                        }
+                    }
+                    if (send_node_id != -1) {  // publish the per-(token, NVL peer) data-ready flags
+                        const auto src_ptr = reinterpret_cast<uint64_t>(data_ready_send_buffer);
+                        const auto data_ready_counter_ptr = reinterpret_cast<uint64_t>(data_ready_counter) +
+                            (rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks * sizeof(int) +
+                            token_idx * num_nvl_ranks * sizeof(int);
+                        const auto data_ready_counter_p2p_ptr = nvshmemi_get_p2p_ptr(data_ready_counter_ptr, rank, real_write_dst_rank);
+                        if (data_ready_counter_p2p_ptr == 0) {
+                            nvshmemi_ibgda_put_nbi_warp(data_ready_counter_ptr,
+                                                        src_ptr,
+                                                        num_nvl_ranks * sizeof(int),
+                                                        real_write_dst_rank,
+                                                        dst_expert_local_idx,
+                                                        lane_id,
+                                                        slot_idx + 1);
+                        } else {
+                            const auto* src_int_ptr = reinterpret_cast<const int*>(src_ptr);
+                            const auto* dst_int_ptr = reinterpret_cast<int*>(data_ready_counter_p2p_ptr);
+                            UNROLLED_WARP_COPY(1, lane_id, num_nvl_ranks, dst_int_ptr, src_int_ptr, ld_nc_global, st_na_global);
+                        }
+                    }
+                }
+                // Send the meta message (kept per-(expert, source rank) as before)
+                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
+                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_meta) +
+                    dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                    rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta + slot_idx * num_bytes_per_meta;
+                const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+                if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
+                    if (dst_p2p_ptr == 0) {
+                        nvshmemi_ibgda_put_nbi_warp(
+                            dst_ptr, src_ptr, num_bytes_per_meta, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+                    } else {
+                        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                        const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+                        UNROLLED_WARP_COPY(
+                            1, lane_id, num_bytes_per_meta / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                    }
+                }
+            } else {
+                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
+                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
+                    dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
+                    rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;
+                const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+                if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
+                    if (dst_p2p_ptr == 0) {
+                        nvshmemi_ibgda_put_nbi_warp(
+                            dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+                    } else {
+                        // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
+                        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                        const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+                        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                    }
                 }
             }
 
             // Increase counter after finishing
             __syncwarp();
             lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
+            if (!disable_ll_layered) {
+                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + real_dst_expert_id, 1) : 0;
+            }
         }
     } else if (warp_id == num_warps - 1) {
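A worked example of the same-rail mapping and per-node dedup above, assuming `NUM_MAX_NVL_PEERS == 8` (the helper is hypothetical; the kernel inlines this arithmetic):

```cpp
#include <cassert>

// Hypothetical helper: the same-rail destination computed in the hunk above.
inline int real_write_dst_rank(int dst_rank, int src_rank, int nvl /* NUM_MAX_NVL_PEERS */) {
    return dst_rank / nvl * nvl + src_rank % nvl;  // destination node, sender's local GPU id
}

int main() {
    // Sender rank 13 (node 1, local GPU 5) dispatching to an expert on rank 18 (node 2):
    assert(real_write_dst_rank(18, 13, 8) == 21);  // node 2, local GPU 5 -> same-rail RDMA
    // If two top-k experts of the token live on node 2, the __shfl_sync loop above
    // cancels the second payload send (send_node_id = -1); only the tiny meta
    // message is still issued once per expert.
    return 0;
}
```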
@@ -283,21 +396,23 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
         if (sm_id == 0) {
             // The first SM is also responsible for checking QPs
             EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
 
+            if (disable_ll_layered) {
+                // The first SM is also responsible for cleaning the next buffer
+                #pragma unroll
+                for (int i = lane_id; i < num_next_clean_int; i += 32)
+                    next_clean[i] = 0;
-            // The first SM is also responsible for cleaning the next buffer
-            #pragma unroll
-            for (int i = lane_id; i < num_next_clean_int; i += 32)
-                next_clean[i] = 0;
-
-            // Notify before executing `int_p`
-            __syncwarp();
-            #pragma unroll
-            for (int i = lane_id; i < num_experts; i += 32)
-                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+
+                // Notify before executing `int_p`
+                __syncwarp();
+                #pragma unroll
+                for (int i = lane_id; i < num_experts; i += 32)
+                    atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+            }
         }
 
         // This SM should be responsible for some destination experts, read `topk_idx` for them
         int expert_count[kNumMaxWarpGroups] = {0};
+        int waiting_flag[kNumMaxWarpGroups] = {0};
         const auto expert_begin_idx = sm_id * num_warp_groups;
         const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
 
@@ -307,18 +422,61 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
             auto idx = static_cast<int>(__ldg(topk_idx + i));
             if (idx >= expert_begin_idx and idx < expert_end_idx)
                 expert_count[idx - expert_begin_idx]++;
+            if (!disable_ll_layered) {  // layered path: also count the same-rail forwards this slot must wait for
+                if (idx < 0)
+                    continue;
+                const auto dst_rank = idx / num_local_experts;
+                const auto dst_expert_local_idx = idx % num_local_experts;
+                auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks;
+                auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
+                if (real_dst_expert_id >= expert_begin_idx and real_dst_expert_id < expert_end_idx)
+                    waiting_flag[real_dst_expert_id - expert_begin_idx]++;
+            }
         }
 
         // Warp reduce
         #pragma unroll
         for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
             auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
+            auto waiting_flag_sum = 0;
+            if (!disable_ll_layered) {  // layered path only
+                waiting_flag_sum = warp_reduce_sum(waiting_flag[i - expert_begin_idx]);
+            }
             if (lane_id == 0) {
                 shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
-                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
+                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - waiting_flag_sum - sum);
             }
         }
     }
 
+    if (!disable_ll_layered and sm_id == num_sms - 1) {  // layered path only
+        // The last SM is responsible for cleaning the next buffer
+        #pragma unroll
+        for (int i = thread_id; i < num_experts; i += blockDim.x)  // clean for combine
+            next_clean[i] = 0;
+        // Clean the data-ready flags (through the P2P mapping of each NVL peer)
+        #pragma unroll 8
+        for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += blockDim.x) {
+            int token_idx = i / num_ranks;
+            int rank_id = i % num_ranks;
+            auto node_id = rank_id / num_nvl_ranks;
+            auto nvl_rank_id = rank_id % num_nvl_ranks;
+            auto* data_ready_flag_ptr = next_clean_data_ready_counter +
+                node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + token_idx * num_nvl_ranks + rank % num_nvl_ranks;
+            EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter <
+                             num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks);
+            const auto data_ready_p2p_src_ptr =
+                nvshmemi_get_p2p_ptr(uint64_t(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
+            reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
+        }
+        __syncthreads();
+        #pragma unroll
+        for (int i = thread_id; i < num_experts; i += blockDim.x)
+            atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+    }
+
     __syncthreads();
 
     // Issue count sends
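The changed warp-reduce keeps the finish-counter invariant balanced: each same-rail forward adds one extra increment to `real_dst_expert_id`, and the reducer now subtracts `waiting_flag_sum` to compensate. An arithmetic sanity check (illustrative values; it assumes, as in the surrounding kernel, that senders spin until the counter reaches twice the tag):

```cpp
#include <cassert>

// Contributions to one expert's atomic_finish_counter_per_expert slot
// on the layered path.
int main() {
    const int FINISHED_SUM_TAG = 1024;  // illustrative value
    int sum = 5;                        // tokens sent to this expert
    int waiting_flag_sum = 3;           // same-rail forwards counted for this slot
    int counter = 0;
    counter += sum;                                        // +1 per token send
    counter += waiting_flag_sum;                           // +1 per same-rail forward
    counter += FINISHED_SUM_TAG;                           // cleanup SM's tag
    counter += FINISHED_SUM_TAG - waiting_flag_sum - sum;  // warp-reduce above
    assert(counter == 2 * FINISHED_SUM_TAG);               // sender's spin target
    return 0;
}
```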
@@ -363,9 +521,17 @@ LOW_LATENCY_DISPATCH_RECV:
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
         const auto local_expert_idx = responsible_expert_idx % num_local_experts;
-        const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
-            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-            src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+        uint8_t* rdma_recv_x_uint8 = nullptr;
+        if (disable_ll_layered) {
+            rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
+                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
+                src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+        } else {
+            rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x_meta) +
+                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
+        }
         const auto recv_x_int4 =
             static_cast<int4*>(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
         const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
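The hunk below makes the receiver spin on its per-(token, NVL peer) flag before touching the payload. Condensed, the handshake looks like this (a sketch only; it reuses the kernel's own helpers `ld_acquire_sys_global`, `trap`, and `NUM_TIMEOUT_CYCLES`, so it compiles only inside this translation unit):

```cpp
// Condensed sketch of the data-ready handshake used in the next hunk.
// The sender RDMA-writes the payload, then writes 2 into the receiver's
// data_ready_counter slot; the receiver acquire-loads until it sees 2.
__device__ __forceinline__ void wait_data_ready(int* flag_p2p_ptr) {
    auto start_time = clock64();
    while (ld_acquire_sys_global(flag_p2p_ptr) != 2) {  // 2 == "payload written"
        if (clock64() - start_time >= NUM_TIMEOUT_CYCLES)
            trap();  // the real code prints rank/token diagnostics first
    }
}
```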
@@ -420,19 +586,60 @@ LOW_LATENCY_DISPATCH_RECV:
         asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
+        const auto real_read_src_rank = src_rank % num_nvl_ranks + rank / num_nvl_ranks * num_nvl_ranks;
 
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
-            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
-            if (lane_id == 0)
-                recv_src_info[recv_token_begin_idx + i] = pack2(ld_nc_global(src_src_idx), src_rank);
-            __syncwarp();
+            int4* src_data = nullptr;
+            if (!disable_ll_layered) {
+                const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_meta);
+                int src_token_idx = 0;
+                if (lane_id == 0) {
+                    src_token_idx = ld_nc_global(src_src_idx);
+                    recv_src_info[recv_token_begin_idx + i] = pack2(src_token_idx, src_rank);
+                }
+                src_token_idx = __shfl_sync(0xffffffff, src_token_idx, 0);
+                const auto data_ready_flag_src_ptr = data_ready_counter +
+                    (src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
+                    src_token_idx * num_nvl_ranks + rank % num_nvl_ranks;
+                const auto src_data_ready_flag_p2p_ptr =
+                    reinterpret_cast<int*>(nvshmemi_get_p2p_ptr(uint64_t(data_ready_flag_src_ptr), rank, real_read_src_rank));
+                if (lane_id == 0) {
+                    int tmp = 0;
+                    auto start_time = clock64();
+                    while (tmp != 2) {  // wait for the payload to be ready
+                        tmp = ld_acquire_sys_global(src_data_ready_flag_p2p_ptr);
+                        if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
+                            printf("DeepEP LL dispatch recv data timeout: src_rank: %d, dst_rank: %d, real_read_src_rank: %d, "
+                                   "src_token_idx: %d, lane: %d, num_recv_tokens: %d\n",
+                                   src_rank, rank, real_read_src_rank, src_token_idx, lane_id, num_recv_tokens);
+                            trap();
+                        }
+                    }
+                }
+                __syncwarp();
+                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
+                    (src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
+                    src_token_idx * num_bytes_per_data;
+                src_data = reinterpret_cast<int4*>(nvshmemi_get_p2p_ptr(src_ptr, rank, real_read_src_rank));
+            } else {
+                const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
+                if (lane_id == 0)
+                    recv_src_info[recv_token_begin_idx + i] = pack2(ld_nc_global(src_src_idx), src_rank);
+                __syncwarp();
+                src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
+            }
 
             // Copy data
             // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
-            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
             const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
             UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
 
@@ -462,7 +669,8 @@ LOW_LATENCY_DISPATCH_RECV:
     }
 }
 
-void dispatch(void* packed_recv_x,
+void dispatch(bool disable_ll_layered,
+              void* packed_recv_x,
               void* packed_recv_x_scales,
               int64_t* packed_recv_src_info,
               int64_t* packed_recv_layout_range,
@@ -519,6 +727,7 @@ void dispatch(void* packed_recv_x,
     dispatch_func = dispatch