@@ -159,6 +159,13 @@ def __init__(self, activation_sparsity: float, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_rocm() and approximate == "tanh":
+            # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is currently "
+                "unstable and produces garbage. Falling back to the 'none' approximation."
+            )
+            self.approximate = "none"

         # Sparsity.
         if activation_sparsity == 0.0:
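The fallback is decided once, at construction time, and `logger.warning_once` keeps the log clean when many layers are built. As a rough illustration of the deduplication semantics (a minimal stand-in using only the standard library, not vLLM's actual logger implementation):

```python
import functools
import logging

logger = logging.getLogger("vllm.activation")

@functools.lru_cache(maxsize=None)
def warning_once(msg: str) -> None:
    # lru_cache keys on the message string, so repeated calls with the
    # same message hit the cache and the warning is emitted only once.
    logger.warning(msg)
```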
@@ -209,6 +216,12 @@ def __init__(self, approximate: str = "none"):
                 self.op = torch.ops._C.gelu_and_mul
             elif approximate == "tanh":
                 self.op = torch.ops._C.gelu_tanh_and_mul
+            if current_platform.is_rocm() and approximate == "tanh":
+                logger.warning_once(
+                    "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
+                    "with torch.compile. The native implementation falls back to the "
+                    "'none' approximation; the custom kernel implementation is unaffected."
+                )
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops

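Note that in this constructor the ROCm branch only warns: `self.op` stays bound to the custom `gelu_tanh_and_mul` kernel, and only the native path (below) switches approximations. To gauge what that switch costs numerically, one can compare the two PyTorch GELU variants directly (a quick check, not part of the patch):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=4096)
exact = F.gelu(x, approximate="none")  # erf-based GELU
tanh = F.gelu(x, approximate="tanh")   # tanh approximation
# The two variants differ only slightly (max abs deviation well under 1e-2),
# which is why swapping approximations is an acceptable fallback.
print((exact - tanh).abs().max())
```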
@@ -219,8 +232,12 @@ def __init__(self, approximate: str = "none"):

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        approximate = self.approximate
+        if current_platform.is_rocm() and approximate == "tanh":
+            approximate = "none"
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
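For reference, `forward_native` implements GELU-and-mul gating: the last dimension is split in half, the first half is passed through GELU, and the result gates the second half. A toy shape check under that reading:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)  # last dim is 2 * d
d = x.shape[-1] // 2
out = F.gelu(x[..., :d], approximate="none") * x[..., d:]
assert out.shape == (4, 8)  # output has d features, half the input's
```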
@@ -522,7 +539,16 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     "gelu": lambda: nn.GELU(),
     "gelu_fast": lambda: FastGELU(),
     "gelu_new": lambda: NewGELU(),
-    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
+    "gelu_pytorch_tanh": lambda: (
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        logger.warning_once(
+            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
+            "Falling back to GELU(approximate='none')."
+        ),
+        nn.GELU(approximate="none"),
+    )[1]
+    if current_platform.is_rocm()
+    else nn.GELU(approximate="tanh"),
     "relu": lambda: nn.ReLU(),
     "relu2": lambda: ReLUSquaredActivation(),
     "silu": lambda: nn.SiLU(),
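The `gelu_pytorch_tanh` entry needs to emit a warning from inside a lambda, which cannot contain statements, hence the `(side_effect, value)[1]` tuple idiom: both elements are evaluated left to right and only the second is kept. A self-contained illustration of the same trick:

```python
# Hypothetical example of the tuple-indexing idiom used above.
make = lambda: (print("constructing fallback GELU"), 42)[1]
assert make() == 42  # the print runs first, then 42 is returned
```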