 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
+#include "cutlass_extensions/common.hpp"
 
 #include "cute/tensor.hpp"
 #include "cutlass/tensor_ref.h"
@@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
 }
 
 template <typename OutType>
-void run_fp4_blockwise_scaled_group_mm(
+void run_fp4_blockwise_scaled_group_mm_sm100(
     torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
     const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
     const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
@@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
 
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
-              "Failed to implement GEMM");
+              "Failed to implement GEMM: status=", (int)can_implement_status);
+
+  // Run the GEMM
+  auto status = gemm_op.initialize(args, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess,
+              "Failed to initialize GEMM: status=", (int)status,
+              " workspace_size=", workspace_size, " num_experts=", num_experts,
+              " M=", M, " N=", N, " K=", K);
+
+  status = gemm_op.run(args, workspace.data_ptr(), stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+}
+
+void run_fp4_blockwise_scaled_group_mm_sm120(
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+    int N, int K) {
+  using ProblemShape =
+      cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
+  using ElementType = cutlass::float_e2m1_t;
+  using ElementSFType = cutlass::float_ue4m3_t;
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+
+  // NOTE: SM120 does not appear to support templating the output type, so
+  // the output type is hardcoded to bfloat16 here.
+  using ElementC = cutlass::bfloat16_t;
+  using ElementD = ElementC;
+  using ElementAccumulator = float;
+  // Layout definitions
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = LayoutC;
+
+  // Alignment constraints
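+  // FP4 values are 4 bits wide, so a 128-bit vectorized access covers
+  // 128 / 4 = 32 elements; the C/D alignments below are likewise 128 bits
+  // divided by the element width.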
+  static constexpr int AlignmentA = 32;
+  static constexpr int AlignmentB = 32;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // Architecture definitions
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _128>;
+
+  using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
+      ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, MmaTileShape, ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
+          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
+          LayoutD*, AlignmentD,
+          cutlass::epilogue::collective::EpilogueScheduleAuto,
+          FusionOperation>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
+          LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
+                                           CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+  using LayoutSFA =
+      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+  using LayoutSFB =
+      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+  using ScaleConfig =
+      typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
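+  // LayoutSFA/LayoutSFB describe the block-scale-factor layouts consumed by
+  // the mainloop; ScaleConfig derives those layouts from each expert's
+  // problem shape.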
+
+  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+  int num_experts = static_cast<int>(expert_offsets.size(0));
+  auto options_int =
+      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
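+  // Each serialized LayoutSFA/LayoutSFB fits in five int64 slots per expert,
+  // hence the {num_experts, 5} shapes above.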
+  torch::Tensor c_strides1 =
+      torch::full({num_experts}, output.stride(0), options_int);
+  torch::Tensor a_strides1 =
+      torch::full({num_experts}, a.stride(0) * 2, options_int);
+  torch::Tensor b_strides1 =
+      torch::full({num_experts}, b.stride(1) * 2, options_int);
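+  // a and b store two packed FP4 values per byte, so the FP4-element strides
+  // handed to the kernel are the byte strides times 2.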
+
+  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
+      a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
+      layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
+      expert_offsets, sf_offsets, problem_sizes, M, N, K);
+
+  // Create an instance of the GEMM
+  Gemm gemm_op;
+
+  // Reinterpret problem_sizes as an array of per-expert (M, N, K) shapes
+  UnderlyingProblemShape* problem_sizes_as_shapes =
+      static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
+
+  // Set the scheduler info
+  cutlass::KernelHardwareInfo hw_info;
+  using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+  scheduler.raster_order = RasterOrderOptions::AlongM;
+  hw_info.device_id = a.get_device();
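+  // Cache the SM count per device so repeated calls avoid the device query.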
+  static std::unordered_map<int, int> cached_sm_counts;
+  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+    cached_sm_counts[hw_info.device_id] =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+            hw_info.device_id);
+  }
+  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+  // Mainloop Arguments
+  typename GemmKernel::MainloopArguments mainloop_args{
+      static_cast<const ElementType**>(a_ptrs.data_ptr()),
+      static_cast<StrideA*>(a_strides1.data_ptr()),
+      static_cast<const ElementType**>(b_ptrs.data_ptr()),
+      static_cast<StrideB*>(b_strides1.data_ptr()),
+      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
+      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
+
+  // Epilogue Arguments
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {},  // epilogue.thread
+      nullptr,
+      static_cast<StrideC*>(c_strides1.data_ptr()),
+      static_cast<ElementD**>(out_ptrs.data_ptr()),
+      static_cast<StrideC*>(c_strides1.data_ptr())};
+  auto& fusion_args = epilogue_args.thread;
+  fusion_args.alpha_ptr_array =
+      reinterpret_cast<float**>(alpha_ptrs.data_ptr());
+  fusion_args.dAlpha = {_0{}, _0{}, 1};
+  fusion_args.beta = 0.0f;
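+  // alpha_ptr_array applies a per-expert scale to each group's output;
+  // beta = 0 means no C-matrix contribution is read.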
+
+  // Gemm Arguments
+  typename GemmKernel::Arguments args{
+      cutlass::gemm::GemmUniversalMode::kGrouped,
+      {num_experts, problem_sizes_as_shapes, nullptr},
+      mainloop_args,
+      epilogue_args,
+      hw_info,
+      scheduler};
+
+  size_t workspace_size = Gemm::get_workspace_size(args);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+  auto can_implement_status = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
+              "Failed to implement GEMM: status=", (int)can_implement_status);
 
   // Run the GEMM
   auto status = gemm_op.initialize(args, workspace.data_ptr());
-  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
+  TORCH_CHECK(status == cutlass::Status::kSuccess,
+              "Failed to initialize GEMM: status=", (int)status,
+              " workspace_size=", workspace_size, " num_experts=", num_experts,
+              " M=", M, " N=", N, " K=", K);
 
   status = gemm_op.run(args, workspace.data_ptr(), stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
 
+template <typename OutType>
+void run_fp4_blockwise_scaled_group_mm(
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+    int N, int K) {
+  int32_t version_num = get_sm_version_num();
+#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+  if (version_num >= 120 && version_num < 130) {
+    run_fp4_blockwise_scaled_group_mm_sm120(
+        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+        expert_offsets, sf_offsets, M, N, K);
+    return;
+  }
+#endif
 #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+  if (version_num >= 100 && version_num < 120) {
+    run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
+        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+        expert_offsets, sf_offsets, M, N, K);
+    return;
+  }
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
+      version_num, ". Required capability: 100 or 120");
+}
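+
+// Illustrative dispatch call (tensor shapes here are assumptions for the
+// sketch, not checked by this comment): with E experts, a and b packed as
+// uint8 holding two FP4 values per byte, problem_sizes an [E, 3] int32
+// tensor of per-expert (M, N, K), and alphas one float per expert:
+//
+//   run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
+//       output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+//       expert_offsets, sf_offsets, M, N, K);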
+
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+    (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
 constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
 #endif
@@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
     const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
     const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+    (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
   // Input validation
   CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
   CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
         output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
         expert_offsets, sf_offsets, M, N, K);
   } else {
+#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+    int32_t version_num = get_sm_version_num();
+    if (version_num >= 120 && version_num < 130) {
+      TORCH_CHECK_NOT_IMPLEMENTED(
+          false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
+          output.scalar_type());
+    }
+#endif
     run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
         output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
         expert_offsets, sf_offsets, M, N, K);
@@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
     TORCH_CHECK_NOT_IMPLEMENTED(
         false,
         "No compiled cutlass_fp4_group_mm kernel, vLLM must "
-        "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
-        "12.8 or above.");
+        "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for "
+        "SM100/120 and CUDA 12.8 or above.");
 #endif
 }
 