|
6 | 6 | from abc import abstractmethod |
7 | 7 | from typing import Any, Callable |
8 | 8 |
|
| 9 | +from packaging.version import parse as parse_version |
| 10 | + |
9 | 11 | import awkward as ak |
10 | 12 | from awkward._nplikes.cupy import Cupy |
11 | 13 | from awkward._nplikes.jax import Jax |
@@ -102,16 +104,20 @@ def __init__(self, impl: Callable[..., Any], key: KernelKeyType): |
102 | 104 |
|
103 | 105 | self._jax = Jax.instance() |
104 | 106 |
|
| 107 | + jax_module = ak.jax.import_jax() |
| 108 | + self._ad_tracer_types = (jax_module._src.interpreters.ad.JVPTracer,) |
| 109 | + if parse_version(jax_module.__version__) >= parse_version("0.7.0"): |
| 110 | + self._ad_tracer_types += (jax_module._src.interpreters.ad.LinearizeTracer,) |
| 111 | + |
105 | 112 | def _cast(self, x, t): |
106 | 113 | if issubclass(t, ctypes._Pointer): |
107 | 114 | # Do we have a JAX-owned array? |
108 | 115 | if self._jax.is_own_array(x): |
109 | 116 | if self._jax.is_tracer_type(type(x)): |
110 | | - jax_module = ak.jax.import_jax() |
111 | 117 | # general message for any invalid JAX input type |
112 | 118 | msg = f"Encountered {x} as an (invalid) input to the '{self._key[0]}' Awkward C++ kernel." |
113 | 119 | # message specification for autodiff (i.e. when encountering a JVPTracer) |
114 | | - if isinstance(x, jax_module._src.interpreters.ad.JVPTracer): |
| 120 | + if isinstance(x, self._ad_tracer_types): |
115 | 121 | msg += " This kernel is not differentiable by the JAX backend." |
116 | 122 | raise ValueError(msg) |
117 | 123 | assert self._jax.is_c_contiguous(x), "kernel expects contiguous array" |
|
0 commit comments