Skip to content

Commit 4102a4a

Browse files
authored
fix: filter also for new (v0.7.0) LinearizeTracer type (#3586)
filter also for new (v0.7.0) LinearizeTracer type
1 parent ba736d9 commit 4102a4a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/awkward/_kernels.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from abc import abstractmethod
77
from typing import Any, Callable
88

9+
from packaging.version import parse as parse_version
10+
911
import awkward as ak
1012
from awkward._nplikes.cupy import Cupy
1113
from awkward._nplikes.jax import Jax
@@ -102,16 +104,20 @@ def __init__(self, impl: Callable[..., Any], key: KernelKeyType):
102104

103105
self._jax = Jax.instance()
104106

107+
jax_module = ak.jax.import_jax()
108+
self._ad_tracer_types = (jax_module._src.interpreters.ad.JVPTracer,)
109+
if parse_version(jax_module.__version__) >= parse_version("0.7.0"):
110+
self._ad_tracer_types += (jax_module._src.interpreters.ad.LinearizeTracer,)
111+
105112
def _cast(self, x, t):
106113
if issubclass(t, ctypes._Pointer):
107114
# Do we have a JAX-owned array?
108115
if self._jax.is_own_array(x):
109116
if self._jax.is_tracer_type(type(x)):
110-
jax_module = ak.jax.import_jax()
111117
# general message for any invalid JAX input type
112118
msg = f"Encountered {x} as an (invalid) input to the '{self._key[0]}' Awkward C++ kernel."
113119
# message specification for autodiff (i.e. when encountering a JVPTracer)
114-
if isinstance(x, jax_module._src.interpreters.ad.JVPTracer):
120+
if isinstance(x, self._ad_tracer_types):
115121
msg += " This kernel is not differentiable by the JAX backend."
116122
raise ValueError(msg)
117123
assert self._jax.is_c_contiguous(x), "kernel expects contiguous array"

0 commit comments

Comments
 (0)