3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -604,7 +604,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
206 changes: 202 additions & 4 deletions csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -22,6 +22,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
}

template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
@@ -353,7 +354,203 @@ void run_fp4_blockwise_scaled_group_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}

template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;

using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;

// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Architecture definitions
using ArchTag = cutlass::arch::Sm120;
// Both mainloop and epilogue use BlockScaledTensorOp on SM120
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_128, _128, _128>;
using PerSmTileShape_MNK = Shape<_128, _128, _128>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
LayoutD*, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;

using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;

using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());

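// Per-expert pointer arrays and scale-factor layouts below are filled on the
// device by run_get_group_gemm_starts(); the stride tables are constant across
// experts, so they are built with torch::full here.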
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
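// a and b pack two FP4 (e2m1) values per byte, so the element strides handed
// to CUTLASS are twice the byte-level tensor strides.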
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);

run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);

// Create an instance of the GEMM
Gemm gemm_op;

// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());

// Set the Scheduler info - SM120 uses auto scheduler
cutlass::KernelHardwareInfo hw_info;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
hw_info.device_id = a.get_device();
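// Cache the SM count per device so repeated launches do not re-query the
// device properties.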
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);

// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides1.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides1.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};

// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides1.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides1.data_ptr())};
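// Per-expert alpha scaling is applied through the epilogue fusion: one alpha
// pointer per group, broadcast over M and N (dAlpha strides {0, 0, 1}).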
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array =
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};

// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};

size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");

// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}

template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num == 120) {
run_fp4_blockwise_scaled_group_mm_sm120<OutType>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
return;
}
#endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
if (version_num >= 100 && version_num < 120) {
run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
}

#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
@@ -374,7 +571,8 @@ void cutlass_fp4_group_mm(
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@@ -416,8 +614,8 @@ void cutlass_fp4_group_mm(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
"12.8 or above.");
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
"and CUDA 12.8 or above.");
#endif
}
