Skip to content

Commit 3b5c07c

Browse files
Merge pull request #33688 from jakevdp:fix-dep-dtype
PiperOrigin-RevId: 839817597
2 parents 4fad55b + dd01b37 commit 3b5c07c

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3485,7 +3485,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
34853485
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
34863486
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
34873487
offset = asarray(core.dimension_as_value(offset))
3488-
if not dtypes.issubdtype(offset, np.integer):
3488+
if not dtypes.issubdtype(offset.dtype, np.integer):
34893489
raise TypeError(f"offset must be an integer, got {offset!r}")
34903490
shape_dtype = lax_utils.int_dtype_for_shape(shape, signed=True)
34913491
if (

jax/_src/numpy/reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _cast_to_numeric(operand: Array) -> Array:
202202
return promote_dtypes_numeric(operand)[0]
203203

204204
def _require_integer(arr: Array) -> Array:
205-
if not dtypes.isdtype(arr, ("bool", "integral")):
205+
if not dtypes.isdtype(arr.dtype, ("bool", "integral")):
206206
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
207207
return arr
208208

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4174,7 +4174,7 @@ def __jax_array__(self):
41744174

41754175
x = jnp.array(1)
41764176
a = AlexArray(x)
4177-
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
4177+
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.result_type]:
41784178
self.assertEqual(f(x), f(a))
41794179

41804180
x = AlexArray(jnp.array(1))

0 commit comments

Comments
 (0)