Skip to content

Commit be28303

Browse files
committed
Migrate TPU orchestration from levanter to fray.
This is a lift-and-shift of the existing ray_tpu code to Fray, with the minor adjustments to the Fray cluster interfaces that the integration requires. Apart from Zephyr, existing users are not switched to the new interface by this change.
1 parent 821f9d7 commit be28303

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+3685
-3002
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
name: Fray - Tests
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
branches:
8+
- main
9+
paths:
10+
- lib/fray/**
11+
- .github/workflows/fray-unit-tests.yaml
12+
13+
concurrency:
14+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
15+
cancel-in-progress: true
16+
17+
jobs:
18+
cpu-test:
19+
runs-on: ubuntu-latest
20+
timeout-minutes: 10
21+
strategy:
22+
matrix:
23+
python-version: ["3.12"]
24+
steps:
25+
- name: Checkout code
26+
uses: actions/checkout@v3
27+
28+
- name: Set up Python ${{ matrix.python-version }}
29+
uses: actions/setup-python@v4
30+
with:
31+
python-version: ${{ matrix.python-version }}
32+
33+
- name: Install uv
34+
uses: astral-sh/setup-uv@v7
35+
with:
36+
enable-cache: true
37+
38+
- name: Test fray
39+
env:
40+
CI: true
41+
run: |
42+
cd lib/fray && uv run --group=fray-test --frozen pytest --durations=5 --tb=short -m 'not slow and not tpu_ci' -v -s tests/
43+
44+
tpu-test:
45+
runs-on: [tpu-ci]
46+
timeout-minutes: 10
47+
if: github.event.pull_request.head.repo.full_name == github.repository
48+
steps:
49+
- name: Checkout code
50+
uses: actions/checkout@v3
51+
with:
52+
ref: ${{ (github.event_name == 'pull_request_review' && format('refs/pull/{0}/merge', github.event.pull_request.number)) || '' }}
53+
54+
- name: Run TPU tests in Docker
55+
env:
56+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
57+
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
58+
run: |
59+
bash infra/tpu-ci/clean-tpu.sh
60+
61+
DOCKER_IMAGE="ghcr.io/marin-community/marin/tpu-ci:latest"
62+
echo "Using Docker image: $DOCKER_IMAGE"
63+
64+
# Create UV cache directory
65+
mkdir -p /tmp/uv-cache
66+
chmod 777 /tmp/uv-cache
67+
68+
docker run --rm \
69+
--device /dev/vfio:/dev/vfio \
70+
--shm-size=100g \
71+
--stop-timeout=5 \
72+
--cap-add=SYS_RESOURCE \
73+
--ulimit memlock=68719476736:68719476736 \
74+
-e HF_TOKEN \
75+
-e JAX_COORDINATOR_ADDRESS=127.0.0.1 \
76+
-e JAX_PLATFORMS=tpu,cpu \
77+
-e PJRT_DEVICE=TPU \
78+
-e TPU_MIN_LOG_LEVEL=3 \
79+
-e TPU_STDERR_LOG_LEVEL=3 \
80+
-e UV_CACHE_DIR=/tmp/uv-cache \
81+
-e WANDB_API_KEY \
82+
-e WANDB_MODE=offline \
83+
-v ${{ github.workspace }}:/workspace-src:ro \
84+
-v /tmp/uv-cache:/tmp/uv-cache:rw \
85+
-w /workspace \
86+
$DOCKER_IMAGE \
87+
bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace/lib/fray && timeout --kill-after=5 --signal=TERM 590 uv run --group=fray-tpu-test pytest tests/ -v --tb=short -s --log-cli-level=INFO -m tpu_ci"

.github/workflows/haliax-run_tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Run Tests
1+
name: Haliax - Tests
22

33
on:
44
push:
@@ -15,7 +15,7 @@ on:
1515
- .github/workflows/haliax-*.yaml
1616

1717
jobs:
18-
build:
18+
cpu-test:
1919

2020
runs-on: ubuntu-latest
2121

.github/workflows/marin-unit-tests.yaml

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Marin - Run unit tests
1+
name: Marin - Tests
22
on:
33
push:
44
branches:
@@ -65,52 +65,10 @@ jobs:
6565
HF_TOKEN: ${{ secrets.HF_TOKEN }}
6666
CI: true
6767
run: |
68-
PYTHONPATH=tests:. uv run --frozen pytest -n3 --durations=5 --tb=short -m 'not slow and not tpu_ci' -v tests/
68+
PYTHONPATH=tests:. uv run --frozen pytest -n3 --dist=worksteal --durations=5 --tb=short -m 'not slow and not tpu_ci' -v tests/
6969
70-
tpu-test:
71-
runs-on: [tpu-ci]
72-
timeout-minutes: 10
73-
if: github.event.pull_request.head.repo.full_name == github.repository
74-
steps:
75-
- name: Checkout code
76-
uses: actions/checkout@v3
77-
with:
78-
ref: ${{ (github.event_name == 'pull_request_review' && format('refs/pull/{0}/merge', github.event.pull_request.number)) || '' }}
79-
80-
- name: Run TPU tests in Docker
70+
- name: Test zephyr
8171
env:
82-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
83-
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
72+
CI: true
8473
run: |
85-
bash infra/tpu-ci/clean-tpu.sh
86-
87-
DOCKER_IMAGE="ghcr.io/marin-community/marin/tpu-ci:latest"
88-
echo "Using Docker image: $DOCKER_IMAGE"
89-
90-
# Create UV cache directory
91-
mkdir -p /tmp/uv-cache
92-
chmod 777 /tmp/uv-cache
93-
94-
docker run --rm \
95-
--device /dev/vfio:/dev/vfio \
96-
--shm-size=100g \
97-
--stop-timeout=5 \
98-
--cap-add=SYS_RESOURCE \
99-
--ulimit memlock=68719476736:68719476736 \
100-
-e TPU_CI=true \
101-
-e JAX_COORDINATOR_ADDRESS=127.0.0.1 \
102-
-e START_RAY_TPU_CLUSTER=true \
103-
-e TPU_STDERR_LOG_LEVEL=3 \
104-
-e TPU_MIN_LOG_LEVEL=3 \
105-
-e PYTHONPATH=/workspace \
106-
-e JAX_PLATFORMS=tpu,cpu \
107-
-e PJRT_DEVICE=TPU \
108-
-e HF_TOKEN \
109-
-e WANDB_API_KEY \
110-
-e WANDB_MODE=offline \
111-
-e UV_CACHE_DIR=/tmp/uv-cache \
112-
-v ${{ github.workspace }}:/workspace-src:ro \
113-
-v /tmp/uv-cache:/tmp/uv-cache:rw \
114-
-w /workspace \
115-
$DOCKER_IMAGE \
116-
bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace && timeout --kill-after=5 --signal=TERM 590 uv run --package marin --extra tpu --group test --frozen pytest tests/tpu -m tpu_ci -v --tb=short -s --log-cli-level=INFO"
74+
cd lib/zephyr && uv run --frozen pytest -n3 --dist=worksteal --durations=5 --tb=short -m 'not slow and not tpu_ci' -v tests/

infra/tpu-ci/vm_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,6 @@ def debug_tpu(name: str, test_path: str, pytest_args: str, timeout: int, env_var
789789
-e TPU_CI=true \\
790790
-e JAX_COORDINATOR_ADDRESS=127.0.0.1 \\
791791
-e START_RAY_TPU_CLUSTER=true \\
792-
-e PYTHONPATH=/workspace \\
793792
-e UV_PROJECT_ENVIRONMENT=/opt/marin/.venv \\
794793
{env_var_flags} -v {remote_dir}:/workspace:rw \\
795794
--tmpfs /workspace/logs:rw \\

lib/fray/pyproject.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@ dependencies = [
1616
"pyyaml>=6.0",
1717
]
1818

19+
[project.scripts]
20+
fray = "fray.cli:main"
21+
1922
[project.optional-dependencies]
2023
ray = ["ray>=2.45"]
2124

2225
[dependency-groups]
23-
test = ["pytest>=8.3.2", "pytest-timeout"]
24-
dev = [{ include-group = "test" }]
26+
fray_test = ["pytest>=8.3.2", "pytest-timeout", "ray[default]", "numpy",]
27+
fray_tpu_test = ["jax[tpu]", { include-group = "fray_test" }]
28+
29+
dev = [{ include-group = "fray_test" }]
2530

2631
[tool.hatch.build.targets.wheel]
2732
packages = ["src/fray"]

0 commit comments

Comments
 (0)