
Commit 4b40924

[ROCm] Fallback pytorch GELU with tanh approximation to GELU() (#29244)
Signed-off-by: Divakar Verma <[email protected]>
Signed-off-by: Divakar Verma <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent c0dfc89 commit 4b40924

File tree

1 file changed, +28 -2 lines changed

vllm/model_executor/layers/activation.py

Lines changed: 28 additions & 2 deletions
@@ -159,6 +159,13 @@ def __init__(self, activation_sparsity: float, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_rocm() and approximate == "tanh":
+            # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is currently "
+                "unstable and produces garbage. Falling back to the 'none' approximation."
+            )
+            self.approximate = "none"
 
         # Sparsity.
         if activation_sparsity == 0.0:
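For context on what the fallback trades away, the short standalone Python sketch below (not part of this commit; the tensor size is illustrative) compares the two approximation modes that the new branch switches between:

import torch
import torch.nn.functional as F

x = torch.randn(4096)

exact = F.gelu(x, approximate="none")   # erf-based GELU, used after the fallback
approx = F.gelu(x, approximate="tanh")  # tanh approximation, skipped on ROCm here

# The two modes agree up to a small approximation error.
print((exact - approx).abs().max())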
@@ -209,6 +216,12 @@ def __init__(self, approximate: str = "none"):
             self.op = torch.ops._C.gelu_and_mul
         elif approximate == "tanh":
             self.op = torch.ops._C.gelu_tanh_and_mul
+            if current_platform.is_rocm() and approximate == "tanh":
+                logger.warning_once(
+                    "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
+                    "with torch.compile. The native implementation falls back to the "
+                    "'none' approximation; the custom kernel implementation is unaffected."
+                )
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
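For reference, the operation that the fused gelu_tanh_and_mul kernel implements is sketched below in plain PyTorch, mirroring forward_native in the next hunk; the function name gelu_tanh_and_mul_ref is illustrative, not a vLLM API:

import torch
import torch.nn.functional as F

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Unfused reference: gate the first half of the last dimension with the
    # tanh-approximated GELU and multiply elementwise by the second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]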

@@ -219,8 +232,12 @@ def __init__(self, approximate: str = "none"):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        approximate = self.approximate
+        if current_platform.is_rocm() and approximate == "tanh":
+            approximate = "none"
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
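A standalone usage sketch of the shape contract in forward_native (tensor sizes are illustrative): the input stacks a gate half and a value half along the last dimension, so the output is half as wide.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 256)                                 # [..., 2 * d]
d = x.shape[-1] // 2
out = F.gelu(x[..., :d], approximate="none") * x[..., d:]  # ROCm fallback path
assert out.shape == (2, 8, 128)                            # keeps only d features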
@@ -522,7 +539,16 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     "gelu": lambda: nn.GELU(),
     "gelu_fast": lambda: FastGELU(),
     "gelu_new": lambda: NewGELU(),
-    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
+    "gelu_pytorch_tanh": lambda: (
+        # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+        logger.warning_once(
+            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
+            "Falling back to GELU(approximate='none')."
+        ),
+        nn.GELU(approximate="none"),
+    )[1]
+    if current_platform.is_rocm()
+    else nn.GELU(approximate="tanh"),
     "relu": lambda: nn.ReLU(),
     "relu2": lambda: ReLUSquaredActivation(),
     "silu": lambda: nn.SiLU(),

0 commit comments
