Fix torch test meshes, run torch tests in CI #1289
base: main
Conversation
Pull Request Overview
This PR adds infrastructure for running PyTorch-dependent tests in CI and ensures proper JAX mesh context management across model roundtrip tests.
- Adds a `torch` pytest marker to the `skip_if_no_torch` decorator to enable selective test execution
- Introduces `use_test_mesh()` context manager usage across multiple model roundtrip tests (sketched below)
- Creates a new GitHub Actions workflow for running torch-specific tests
- Removes a test that depends on unavailable models
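As a rough illustration of the pattern the PR applies (a minimal sketch; the import path and exact signatures of `skip_if_no_torch` and `use_test_mesh` are assumptions based on the file summaries below, not the project's verbatim API), a roundtrip test wrapped this way might look like:

```python
import pytest

# Assumed to live in tests/test_utils.py, per the file summary below.
from test_utils import skip_if_no_torch, use_test_mesh


@skip_if_no_torch
def test_model_roundtrip():
    # Everything that touches JAX devices runs under the shared test mesh,
    # so sharded parameters are created against a consistent mesh context.
    with use_test_mesh():
        ...  # load HF weights, convert to JAX, export back, compare outputs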
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/test_utils.py | Adds torch marker to skip_if_no_torch decorator for test filtering |
| tests/test_qwen3.py | Imports and applies use_test_mesh context manager to roundtrip test |
| tests/test_qwen2.py | Imports and applies use_test_mesh context manager to roundtrip test |
| tests/test_olmo.py | Imports and applies use_test_mesh context manager to roundtrip test |
| tests/test_mixtral.py | Refactors to use use_test_mesh at tempdir level and removes redundant nested context |
| tests/test_mistral.py | Imports and applies use_test_mesh context manager to roundtrip test |
| tests/test_lora.py | Removes PEFT integration test and applies use_test_mesh to remaining tests |
| tests/test_llama3.py | Imports and applies use_test_mesh context manager to roundtrip test |
| tests/test_llama.py | Imports use_test_mesh and applies it to roundtrip test |
| tests/test_hf_gpt2_serialize.py | Wraps model loading and testing logic in use_test_mesh context |
| tests/test_gemma.py | Imports and applies use_test_mesh context manager to roundtrip tests |
| pyproject.toml | Registers torch as a pytest marker for test filtering |
| .github/workflows/run_torch_tests.yaml | New workflow for running torch-specific tests in CI |
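With the `torch` marker registered in pyproject.toml, the new workflow can select exactly these tests with `pytest -m torch`, and other CI jobs can exclude them with `-m "not torch"`; this is standard pytest marker filtering (the exact invocation used by the workflow is not shown in the diff).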
```yaml
      - name: Install dependencies
        run: uv sync --dev --extra torch_test
      - name: Install PyTorch CPU wheels
        run: uv pip install --index-url https://download.pytorch.org/whl/cpu torch
```
Copilot AI, Nov 3, 2025:
Installing PyTorch after running `uv sync` may cause version conflicts or reinstall dependencies. The torch version constraint (`>=2.7.0`) in the `torch_test` extra of pyproject.toml may not be respected. Consider including torch in the `uv sync` command or explicitly specifying the version here to match the constraint.
Suggested change:

```diff
-        run: uv pip install --index-url https://download.pytorch.org/whl/cpu torch
+        run: uv pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.7.0"
```
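Quoting the bound keeps the ad-hoc install consistent with the extra's constraint; the separate install step presumably exists so CI pulls CPU-only wheels from the PyTorch index (`/whl/cpu`) rather than the default GPU builds.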
```diff
 def skip_if_no_torch(f):
-    return pytest.mark.skipif(not has_torch(), reason="torch not installed")(f)
+    return pytest.mark.torch(pytest.mark.skipif(not has_torch(), reason="torch not installed")(f))
```
Copilot AI, Nov 3, 2025:
The nested marker application `pytest.mark.torch(pytest.mark.skipif(...)(f))` is incorrect. Pytest markers should be composed by stacking, not nesting. The correct approach is to apply both markers separately, either by stacking decorators or via a marker list; the current implementation calls `pytest.mark.torch` with an already-wrapped function, which is not the intended API usage.
Suggested change:

```diff
-    return pytest.mark.torch(pytest.mark.skipif(not has_torch(), reason="torch not installed")(f))
+    f = pytest.mark.torch(f)
+    f = pytest.mark.skipif(not has_torch(), reason="torch not installed")(f)
+    return f
```
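For reference, the same composition at an individual test site would use the standard stacked-decorator idiom (a minimal sketch; the import path of the project's `has_torch` helper is an assumption):

```python
import pytest

# has_torch() is the project's existing availability check (assumed import path).
from test_utils import has_torch


# Stacked markers: each decorator attaches one mark to the test function.
@pytest.mark.torch
@pytest.mark.skipif(not has_torch(), reason="torch not installed")
def test_needs_torch():
    import torch  # safe: the test is skipped when torch is absent

    assert torch.zeros(1).shape == (1,)
```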
Previously we were skipping these in CI because of slowness, but let's just fork off yet another test job to run them.