Skip to content

Commit dc1dddd

Browse files
committed
Add CI to run ty on jax/_src/dtypes.py
1 parent d9ff025 commit dc1dddd

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

.github/workflows/ci-build.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,16 @@ jobs:
189189
JAX_PLATFORM_NAME: cpu
190190
- name: Run GPU tests
191191
run: python -m pytest examples/ffi/tests
192+
193+
ty_typecheck:
194+
name: TypeCheck with ty
195+
runs-on: ubuntu-latest
196+
steps:
197+
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
198+
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
199+
with:
200+
python-version: '3.14'
201+
- run: |
202+
pip install uv~=0.5.30
203+
uv pip install --system ty
204+
- run: uv run ty check .

jax/_src/dtypes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import abc
25+
import builtins
2526
import dataclasses
2627
import functools
2728
import types
@@ -86,7 +87,7 @@ class ExtendedDType(StrictABC):
8687
"""Abstract Base Class for extended dtypes"""
8788
@property
8889
@abc.abstractmethod
89-
def type(self) -> type: ...
90+
def type(self) -> builtins.type: ...
9091

9192

9293
# fp8 support
@@ -634,7 +635,7 @@ def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...
634635
True or False
635636
"""
636637
the_dtype = np.dtype(dtype)
637-
kind_tuple: tuple[str | DTypeLike, ...] = (
638+
kind_tuple: tuple[str | DTypeLike, ...] = ( # ty: ignore[invalid-assignment]
638639
kind if isinstance(kind, tuple) else (kind,)
639640
)
640641
options: set[DType] = set()
@@ -914,7 +915,7 @@ def check_valid_dtype(dtype: DType) -> None:
914915
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
915916
"type. Only arrays of numeric types are supported by JAX.")
916917

917-
def _maybe_canonicalize_explicit_dtype(dtype: DType, fun_name: str) -> DType:
918+
def _maybe_canonicalize_explicit_dtype(dtype: DType, fun_name: str | None) -> DType:
918919
"Canonicalizes explicitly requested dtypes, per explicit_x64_dtypes."
919920
allow = config.explicit_x64_dtypes.value
920921
if allow == config.ExplicitX64Mode.ALLOW or config.enable_x64.value:

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,6 @@ max-complexity = 18
159159
"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
160160
"docs/jep/9407-type-promotion.ipynb" = ["F811"]
161161
"docs/autodidax.ipynb" = ["F811"]
162+
163+
[tool.ty.src]
164+
include = ["jax/_src/dtypes.py"]

0 commit comments

Comments
 (0)