
Commit ee14644

[ROCm] Aiter Quant Kernels (#25552)
Signed-off-by: vllmellm <[email protected]>
1 parent 1166c31 commit ee14644

File tree

3 files changed: +123 -2 lines changed


vllm/_aiter_ops.py

Lines changed: 87 additions & 0 deletions
@@ -9,6 +9,8 @@
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 
+_FP8_DTYPE = current_platform.fp8_dtype()
+
 
 def is_aiter_found() -> bool:
     from importlib.util import find_spec
@@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
     return torch.empty_like(x), torch.empty_like(residual)
 
 
+def _rocm_aiter_per_tensor_quant_impl(
+    x: torch.Tensor,
+    quant_dtype: torch.dtype,
+    scale: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter.ops.quant import per_tensor_quant_hip
+
+    return per_tensor_quant_hip(x, scale, quant_dtype)
+
+
+def _rocm_aiter_per_tensor_quant_fake(
+    x: torch.Tensor,
+    quant_dtype: torch.dtype,
+    scale: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
+        1, dtype=torch.float32, device=x.device
+    )
+
+
+def _rocm_aiter_per_token_quant_impl(
+    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter.ops.quant import dynamic_per_token_scaled_quant
+
+    assert quant_dtype in [torch.int8, _FP8_DTYPE]
+
+    out_shape = x.shape
+    out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
+    if scale is None:
+        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
+    dynamic_per_token_scaled_quant(
+        out,
+        x,
+        scale,
+        scale_ub=None,
+        shuffle_scale=False,
+        num_rows=None,
+        num_rows_factor=1,
+    )
+    return out, scale
+
+
+def _rocm_aiter_per_token_quant_fake(
+    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    out_shape = x.shape
+    return (
+        torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
+        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
+    )
+
+
 # Global flag to ensure ops are registered only once
 _OPS_REGISTERED = False
 
@@ -665,6 +720,22 @@ def register_ops_once() -> None:
             dispatch_key=current_platform.dispatch_key,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_per_tensor_quant",
+            op_func=_rocm_aiter_per_tensor_quant_impl,
+            mutates_args=[],
+            fake_impl=_rocm_aiter_per_tensor_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
+        direct_register_custom_op(
+            op_name="rocm_aiter_per_token_quant",
+            op_func=_rocm_aiter_per_token_quant_impl,
+            mutates_args=["scale"],
+            fake_impl=_rocm_aiter_per_token_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         _OPS_REGISTERED = True
 
     @staticmethod
@@ -859,6 +930,22 @@ def mla_decode_fwd(
             kv_scale=kv_scale,
         )
 
+    @staticmethod
+    def per_tensor_quant(
+        x: torch.Tensor,
+        quant_dtype: torch.dtype,
+        scale: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
+
+    @staticmethod
+    def per_token_quant(
+        x: torch.Tensor,
+        quant_dtype: torch.dtype,
+        scale: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
+
     @staticmethod
     def triton_fp4_gemm_dynamic_qaunt(
         x: torch.Tensor,
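For orientation, a minimal usage sketch of the two new wrappers (not part of the commit): it assumes a ROCm build with the aiter package installed and that the vLLM custom ops have already been registered (register_ops_once in this file); the tensor shape and dtype are illustrative only.

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform

fp8_dtype = current_platform.fp8_dtype()
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")

# Dynamic per-token quantization: when scale is None the wrapper allocates one
# float32 scale per row, shaped (*x.shape[:-1], 1) -- (8, 1) here.
x_q, x_scale = rocm_aiter_ops.per_token_quant(x, fp8_dtype)

# Per-tensor quantization via aiter's per_tensor_quant_hip; a precomputed
# float32 scale may be passed, otherwise the scale is expected to be computed
# dynamically by the kernel.
y_q, y_scale = rocm_aiter_ops.per_tensor_quant(x, fp8_dtype)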

vllm/model_executor/layers/quantization/input_quant_fp8.py

Lines changed: 31 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import current_platform
@@ -45,10 +46,13 @@ def __init__(
         super().__init__()
         self.static = static
         self.group_shape = group_shape
+        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
         self.num_token_padding = num_token_padding
         self.column_major_scales = column_major_scales
         self.use_ue8m0 = use_ue8m0
 
+        self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()
+
         self.is_group_quant = group_shape.is_per_group()
         if self.is_group_quant:
             assert not static, "Group quantization only supports dynamic mode"
@@ -92,6 +96,33 @@ def forward_cuda(
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        scale_ub: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        use_aiter_quant = (
+            not self.is_group_quant
+            and self.use_aiter
+            and scale_ub is None
+            and x.is_contiguous()
+        )
+        use_aiter_per_tensor_quant = (
+            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
+        )
+        use_aiter_per_token_quant = (
+            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
+        )
+
+        if use_aiter_per_tensor_quant:
+            return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
+        if use_aiter_per_token_quant:
+            return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
+
+        # Fallback to CUDA implementation
+        return self.forward_cuda(x, scale, scale_ub)
+
     def forward_native(
         self,
         x: torch.Tensor,
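As a rough illustration of the new dispatch, the sketch below shows a QuantFP8 layer configured for dynamic per-token quantization taking the aiter route; it assumes a ROCm build where rocm_aiter_ops.is_linear_fp8_enaled() returns True and a constructor signature matching the fields assigned in __init__ above.

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Dynamic (static=False) per-token FP8 quantization layer.
quant = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

x = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")

# On ROCm, CustomOp dispatch selects forward_hip. The input is contiguous,
# no scale_ub is given and group quantization is not in use, so the call is
# routed to rocm_aiter_ops.per_token_quant and returns an FP8 tensor plus
# per-row float32 scales; otherwise it falls back to forward_cuda.
x_q, x_scale = quant(x)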

vllm/platforms/rocm.py

Lines changed: 5 additions & 2 deletions
@@ -381,6 +381,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         compilation_config = vllm_config.compilation_config
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
+        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
+        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
 
         if compilation_config.cudagraph_mode.has_full_cudagraphs():
             # decode context parallel does not support full cudagraphs
@@ -400,8 +402,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
-        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
-
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
@@ -415,6 +415,9 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         ):
             compilation_config.custom_ops.append("+rms_norm")
 
+        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
+            compilation_config.custom_ops.append("+quant_fp8")
+
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
         if model_arch in _ROCM_UNSUPPORTED_MODELS:
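The hook added above force-enables the quant_fp8 custom op when the aiter FP8 linear path is active, so QuantFP8 dispatches to its custom (forward_hip) implementation rather than the torch-native one, unless the user already opted out with "-quant_fp8". For reference, a hedged sketch of the equivalent manual override, assuming the current CompilationConfig API:

from vllm.config import CompilationConfig

# custom_ops entries use a "+"/"-" prefix to force-enable or force-disable a
# named CustomOp; the platform hook appends "+quant_fp8" only when the user
# has not already added "-quant_fp8".
compilation_config = CompilationConfig(custom_ops=["+quant_fp8"])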
