diff --git a/src/ATen/native/xpu/sycl/IGammaKernel.cpp b/src/ATen/native/xpu/sycl/IGammaKernel.cpp index 6ef637fc0d..d638d55781 100644 --- a/src/ATen/native/xpu/sycl/IGammaKernel.cpp +++ b/src/ATen/native/xpu/sycl/IGammaKernel.cpp @@ -9,26 +9,27 @@ namespace at::native::xpu { template struct IgammaFunctor { - IgammaFunctor(bool calc_igammac) : calc_igammac_(calc_igammac) {} - bool calc_igammac_; - [[clang::optnone]] scalar_t operator()(scalar_t a, scalar_t b) const { - if (calc_igammac_) { - return calc_igammac(a, b); - } else { - return calc_igamma(a, b); - } + scalar_t operator()(scalar_t a, scalar_t b) const { + return calc_igamma(a, b); + } +}; + +template +struct IgammacFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + return calc_igammac(a, b); } }; void igamma_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_xpu", [&]() { - gpu_kernel(iter, IgammaFunctor(false)); + gpu_kernel(iter, IgammaFunctor()); }); } void igammac_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igammac_xpu", [&]() { - gpu_kernel(iter, IgammaFunctor(true)); + gpu_kernel(iter, IgammacFunctor()); }); }