9 | 9 | from vllm.platforms import current_platform |
10 | 10 | from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer |
11 | 11 |
| 12 | +_FP8_DTYPE = current_platform.fp8_dtype() |
| 13 | + |
12 | 14 |
13 | 15 | def is_aiter_found() -> bool: |
14 | 16 |     from importlib.util import find_spec
@@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( |
467 | 469 |     return torch.empty_like(x), torch.empty_like(residual)
468 | 470 |
469 | 471 |
| 472 | +def _rocm_aiter_per_tensor_quant_impl(
| 473 | +    x: torch.Tensor,
| 474 | +    quant_dtype: torch.dtype,
| 475 | +    scale: torch.Tensor | None = None,
| 476 | +) -> tuple[torch.Tensor, torch.Tensor]:
| 477 | +    from aiter.ops.quant import per_tensor_quant_hip
| 478 | +
| 479 | +    return per_tensor_quant_hip(x, scale, quant_dtype)
| 480 | +
| 481 | +
| 482 | +def _rocm_aiter_per_tensor_quant_fake(
| 483 | +    x: torch.Tensor,
| 484 | +    quant_dtype: torch.dtype,
| 485 | +    scale: torch.Tensor | None = None,
| 486 | +) -> tuple[torch.Tensor, torch.Tensor]:
| 487 | +    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
| 488 | +        1, dtype=torch.float32, device=x.device
| 489 | +    )
| 491 | + |
| 492 | +def _rocm_aiter_per_token_quant_impl(
| 493 | +    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
| 494 | +) -> tuple[torch.Tensor, torch.Tensor]:
| 495 | +    from aiter.ops.quant import dynamic_per_token_scaled_quant
| 496 | +
| 497 | +    assert quant_dtype in [torch.int8, _FP8_DTYPE]
| 498 | +
| 499 | +    out_shape = x.shape
| 500 | +    # Allocate the output in the requested dtype; the assert above admits
| 501 | +    # int8 as well as fp8, so the buffer must not hardcode _FP8_DTYPE.
| 502 | +    out = torch.empty(out_shape, dtype=quant_dtype, device=x.device)
| 503 | +    if scale is None:
| 504 | +        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
| 505 | +    dynamic_per_token_scaled_quant(
| 506 | +        out,
| 507 | +        x,
| 508 | +        scale,
| 509 | +        scale_ub=None,
| 510 | +        shuffle_scale=False,
| 511 | +        num_rows=None,
| 512 | +        num_rows_factor=1,
| 513 | +    )
| 514 | +    return out, scale
| 513 | + |
| 514 | + |
| 517 | +def _rocm_aiter_per_token_quant_fake(
| 518 | +    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
| 519 | +) -> tuple[torch.Tensor, torch.Tensor]:
| 520 | +    out_shape = x.shape
| 521 | +    return (
| 522 | +        torch.empty(out_shape, dtype=quant_dtype, device=x.device),
| 523 | +        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
| 524 | +    )
| 524 | + |
470 | 525 | # Global flag to ensure ops are registered only once |
471 | 526 | _OPS_REGISTERED = False |
472 | 527 |
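For reference, a minimal pure-PyTorch sketch of the semantics dynamic per-token scaled quantization is assumed to implement: one fp32 scale per row, matching the (*x.shape[:-1], 1) scale buffer allocated above. The helper name per_token_quant_ref and its clamping details are illustrative, not part of aiter or this PR, and the sketch assumes an fp8 quant_dtype:

import torch

def per_token_quant_ref(
    x: torch.Tensor, quant_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row): scale = amax(|row|) / dtype_max.
    dtype_max = torch.finfo(quant_dtype).max
    scale = x.abs().amax(dim=-1, keepdim=True).float() / dtype_max
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)  # guard all-zero rows
    out = (x.float() / scale).clamp(-dtype_max, dtype_max).to(quant_dtype)
    return out, scale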
@@ -665,6 +720,22 @@ def register_ops_once() -> None: |
665 | 720 |             dispatch_key=current_platform.dispatch_key,
666 | 721 |         )
667 | 722 |
| 723 | +        direct_register_custom_op(
| 724 | +            op_name="rocm_aiter_per_tensor_quant",
| 725 | +            op_func=_rocm_aiter_per_tensor_quant_impl,
| 726 | +            mutates_args=[],
| 727 | +            fake_impl=_rocm_aiter_per_tensor_quant_fake,
| 728 | +            dispatch_key=current_platform.dispatch_key,
| 729 | +        )
| 730 | +
| 731 | +        direct_register_custom_op(
| 732 | +            op_name="rocm_aiter_per_token_quant",
| 733 | +            op_func=_rocm_aiter_per_token_quant_impl,
| 734 | +            mutates_args=["scale"],
| 735 | +            fake_impl=_rocm_aiter_per_token_quant_fake,
| 736 | +            dispatch_key=current_platform.dispatch_key,
| 737 | +        )
668 | 739 |         _OPS_REGISTERED = True
669 | 740 |
670 | 741 |     @staticmethod
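The fake impls let torch.compile and fake-tensor tracing propagate output shapes and dtypes without launching the HIP kernels. A hedged sketch of what that enables, assuming register_ops_once() has run on a ROCm build; under FakeTensorMode no real kernel executes, so only the fake is exercised:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(16, 128, dtype=torch.bfloat16, device="cuda")
    # Dispatches to _rocm_aiter_per_token_quant_fake: out has x's shape in the
    # quantized dtype, scale is a (16, 1) fp32 tensor.
    out, scale = torch.ops.vllm.rocm_aiter_per_token_quant(
        x, torch.float8_e4m3fnuz, None  # fnuz: assumed fp8 dtype on gfx94x
    )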
@@ -859,6 +930,22 @@ def mla_decode_fwd( |
859 | 930 |             kv_scale=kv_scale,
860 | 931 |         )
861 | 932 |
| 933 | +    @staticmethod
| 934 | +    def per_tensor_quant(
| 935 | +        x: torch.Tensor,
| 936 | +        quant_dtype: torch.dtype,
| 937 | +        scale: torch.Tensor | None = None,
| 938 | +    ) -> tuple[torch.Tensor, torch.Tensor]:
| 939 | +        return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
| 940 | +
| 941 | +    @staticmethod
| 942 | +    def per_token_quant(
| 943 | +        x: torch.Tensor,
| 944 | +        quant_dtype: torch.dtype,
| 945 | +        scale: torch.Tensor | None = None,
| 946 | +    ) -> tuple[torch.Tensor, torch.Tensor]:
| 947 | +        return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
| 948 | + |
862 | 949 |     @staticmethod
863 | 950 |     def triton_fp4_gemm_dynamic_qaunt(
864 | 951 |         x: torch.Tensor,
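A hedged usage sketch of the new wrappers. The enclosing class and import path (rocm_aiter_ops in vllm._aiter_ops) are assumed from vllm's aiter integration, and running this requires a ROCm build with aiter installed:

import torch
from vllm._aiter_ops import rocm_aiter_ops  # assumed import path

rocm_aiter_ops.register_ops_once()
x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
fp8 = torch.float8_e4m3fnuz  # assumed platform fp8 dtype (gfx94x)

# Dynamic per-token scales: one fp32 scale per row of x.
qx, token_scales = rocm_aiter_ops.per_token_quant(x, fp8)

# Per-tensor quantization with a precomputed scalar scale.
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
qx2, _ = rocm_aiter_ops.per_tensor_quant(x, fp8, scale)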