Description
Fray: Migrate TPU orchestration to Fray Ray backend
Part of epic: #2001
Background
Move the TPU orchestration system from lib/levanter/src/levanter/infra/ray_tpu.py into the Fray Ray backend. This is Ray-specific infrastructure for launching and managing TPU jobs with gang scheduling, multislice coordination, and retry logic.
Current Implementation
Location: lib/levanter/src/levanter/infra/ray_tpu.py (1361 lines)
Tests: lib/levanter/tests/test_ray_tpu.py (481 lines)
Key Components
Actor Hierarchy:
- ResourcePoolManager[ActorInfoT]: Abstract base for managing actor pools
- SlicePoolManager: Manages pools of TPU slices with flexible multislice scheduling
- SliceActor: Manages a single TPU slice (owns TPU-{type}-head resource)
- TPUHostActor: Manages a single TPU host/VM within a slice
Execution Functions:
- run_on_pod(): Main entry point for running functions on TPU pods
- run_on_pod_ray(): Ray remote version with retry logic
- run_on_pod_multislice(): Multislice execution
- run_on_pod_resumable(): Automatic retry on preemption
Supporting Infrastructure:
- TPU configuration database (v4-8, v5p-256, etc.)
- Multislice coordination via MEGASCALE environment variables
- Error classification (preemption vs failure vs system error)
- Libtpu lockfile cleanup
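As a rough illustration of the multislice coordination item above, each slice gets a small set of MEGASCALE_* variables that tell it where the coordinator lives and which slice it is. The sketch below is not the code in ray_tpu.py: the MultisliceInfo field names are assumptions, and the variable names are the ones commonly used for multislice setups, so check them against the real helper before relying on them.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultisliceInfo:
    """Hypothetical shape; the real fields in ray_tpu.py may differ."""

    coordinator_ip: str
    slice_id: int
    num_slices: int
    port: int


def multislice_env_vars(info: MultisliceInfo) -> dict[str, str]:
    """Build the per-slice MEGASCALE_* environment for a multislice job."""
    if not 0 <= info.slice_id < info.num_slices:
        raise ValueError(
            f"slice_id {info.slice_id} is out of range for {info.num_slices} slices"
        )
    return {
        # Every slice points at the same coordinator address.
        "MEGASCALE_COORDINATOR_ADDRESS": f"{info.coordinator_ip}:{info.port}",
        "MEGASCALE_NUM_SLICES": str(info.num_slices),
        "MEGASCALE_PORT": str(info.port),
        # Only the slice id differs between slices.
        "MEGASCALE_SLICE_ID": str(info.slice_id),
    }
```

Environment values must be strings, which is why every field is converted explicitly before it is handed to the runtime environment.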
Dependencies
Levanter-specific (need to move):
- levanter.infra.tpus.get_current_tpu_is_preempted() - TPU metadata
External:
- Ray core, actors, exceptions, scheduling
- draccus (config management)
- mergedeep (runtime env merging)
Not migrating:
- levanter.infra.docker.make_docker_run_command() - Docker code stays in Levanter
- levanter.utils.ray_utils.ser_exc_info() - Dead code, not used
- Docker execution functions - Not used in current workflows
Proposed Migration
Directory Structure
lib/fray/src/fray/cluster/ray/tpu/
├── __init__.py
├── config.py # TPU configurations
├── orchestration.py # Actor hierarchy
├── execution.py # run_on_pod functions
└── utils.py # Error handling, metadata
lib/fray/tests/ray/
└── test_tpu.py # All TPU tests
Phase 1: Extract TPU Utilities
Create lib/fray/src/fray/cluster/ray/tpu/utils.py:
- Move get_current_tpu_is_preempted() from levanter.infra.tpus
- Move TPU metadata access functions
- Move error handling helpers
Phase 2: Move Core Orchestration
Create lib/fray/src/fray/cluster/ray/tpu/config.py:
- Move TPUConfig dataclass
- Move TPU_CONFIGS list
- Move get_tpu_config() function
- Move helper functions (_get_current_tpu_pod_type(), etc.)
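To make the intended shape of config.py concrete, here is a minimal sketch of a configuration table with a lookup function. The names TPUConfig, TPU_CONFIGS, and get_tpu_config() come from the plan above, but the field names are assumptions, and the two entries are illustrative values that should be checked against the real table.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TPUConfig:
    # Field names are assumptions for illustration only.
    name: str
    chip_count: int
    vm_count: int


# Two illustrative entries; the real list covers many more TPU types.
TPU_CONFIGS = [
    TPUConfig("v4-8", chip_count=4, vm_count=1),
    TPUConfig("v5p-256", chip_count=128, vm_count=32),
]

# Index once so lookups do not scan the list on every call.
_CONFIGS_BY_NAME = {config.name: config for config in TPU_CONFIGS}


def get_tpu_config(tpu_type: str) -> TPUConfig:
    """Return the config for a TPU type, or raise ValueError if unknown."""
    try:
        return _CONFIGS_BY_NAME[tpu_type]
    except KeyError:
        raise ValueError(f"Unknown TPU type: {tpu_type!r}") from None
```

Failing loudly on an unknown type surfaces a typo in the TPU name before any resources are requested from the cluster.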
Create lib/fray/src/fray/cluster/ray/tpu/orchestration.py:
- Move all actor classes:
  - ResourcePoolManager[ActorInfoT]
  - SlicePoolManager
  - SliceActor
  - TPUHostActor
- Move result types: TpuSuccess, TpuPreempted, TpuFailed, TpuRunError, TpuCancelled
- Move info types: MultisliceInfo, SliceInfo, TPUHostInfo
- Move helper functions: _multislice_info_from_head(), _multislice_info_to_env_vars()
Create lib/fray/src/fray/cluster/ray/tpu/execution.py:
- Move main execution functions:
  - run_on_pod()
  - run_on_pod_ray()
  - run_on_pod_multislice()
  - run_on_pod_resumable()
  - run_on_pod_multislice_resumable()
- Move helper functions:
  - _start_fn_on_slice()
  - _handle_ray_error()
  - _cancel_tasks_and_wait()
  - _stop_actor()
  - _validate_num_slices()
  - _hacky_remove_tpu_lockfile()
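For reviewers unfamiliar with the resumable variants, the sketch below shows one way the retry behaviour could work: classify each attempt's outcome, then retry preemptions and failures against separate budgets. It is a simplified stand-in, not the code being moved. It models only three of the five result types, and the run_resumable name, the two budget parameters, and the field names are all assumptions.

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass
class TpuSuccess:
    result: object


@dataclass
class TpuPreempted:
    error: Exception


@dataclass
class TpuFailed:
    error: Exception


TpuResult = Union[TpuSuccess, TpuPreempted, TpuFailed]


def run_resumable(
    attempt: Callable[[], TpuResult],
    max_retries_preemption: int,
    max_retries_failure: int,
) -> object:
    """Call attempt() until it succeeds or one retry budget is exhausted."""
    preemptions = 0
    failures = 0
    while True:
        outcome = attempt()
        if isinstance(outcome, TpuSuccess):
            return outcome.result
        if isinstance(outcome, TpuPreempted):
            # Preemptions are expected on preemptible TPUs, so they are
            # counted separately from genuine failures.
            preemptions += 1
            if preemptions > max_retries_preemption:
                raise RuntimeError("Preempted too many times") from outcome.error
        else:
            failures += 1
            if failures > max_retries_failure:
                raise RuntimeError("Failed too many times") from outcome.error
```

Keeping the two counters apart lets a long job survive many preemptions while still stopping quickly on a bug in user code.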
Do NOT migrate:
- Docker-related code (lines 1165-1360):
  - run_docker_on_pod() and related functions
  - RunDockerOnPodConfig
  - submit_tpu_job_on_ray()
- ser_exc_info() from ray_utils (dead code)
Phase 3: Move Tests
Create lib/fray/tests/ray/test_tpu.py:
- Move all tests from lib/levanter/tests/test_ray_tpu.py
- Update imports to use fray.cluster.ray.tpu.*
- Keep all test utilities and fixtures
Test categories:
- Single slice tests (5 tests)
- Multislice tests (4 tests)
- Preemption and failure tests
- Test utilities: simple_jax_fn(), CounterActor, PreemptionCountingActor
Implementation Tasks
- Create lib/fray/src/fray/cluster/ray/tpu/config.py with TPU configurations
- Create lib/fray/src/fray/cluster/ray/tpu/utils.py with TPU metadata functions
- Create lib/fray/src/fray/cluster/ray/tpu/orchestration.py with actor hierarchy
- Create lib/fray/src/fray/cluster/ray/tpu/execution.py with run_on_pod functions
- Move tests to lib/fray/tests/ray/test_tpu.py
- Add comprehensive docstrings and type hints
- Update documentation
Separate PR: Update all users and delete old code:
- Update 9 files that import from levanter.infra.ray_tpu:
  - lib/marin/src/marin/training/training.py
  - lib/marin/src/marin/speedrun/speedrun.py
  - lib/marin/src/marin/rl/rl_job.py
  - lib/marin/src/marin/rl/evaluate_environment.py
  - lib/marin/src/marin/evaluation/visualize.py
  - lib/levanter/infra/launch_on_ray.py
  - Others in scripts/experiments
- Update imports to use fray.cluster.ray.tpu.*
- Delete lib/levanter/src/levanter/infra/ray_tpu.py
- Delete lib/levanter/tests/test_ray_tpu.py
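The import updates in the follow-up PR are mechanical, so a small helper could handle most of them. The sketch below rewrites the old module path in a source string. It assumes fray.cluster.ray.tpu re-exports the public names from its __init__.py, which this issue does not confirm; if it does not, each import has to be pointed at the right submodule by hand. The rewrite_imports name is hypothetical.

```python
import re

OLD_MODULE = "levanter.infra.ray_tpu"
# Assumption: the new package re-exports the public names at its top level.
NEW_MODULE = "fray.cluster.ray.tpu"

# Word boundaries keep longer names such as levanter.infra.ray_tpu_old intact.
_OLD_MODULE_PATTERN = re.compile(r"\b" + re.escape(OLD_MODULE) + r"\b")


def rewrite_imports(source: str) -> str:
    """Replace every reference to the old module path with the new one."""
    return _OLD_MODULE_PATTERN.sub(NEW_MODULE, source)
```

The result should still be reviewed file by file, because the pattern also rewrites mentions of the old path in comments and strings.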
Testing Plan
All existing tests from test_ray_tpu.py:
- test_single_slice_simple_run()
- test_single_slice_run_twice()
- test_single_slice_fail_once()
- test_single_slice_catches_failure()
- test_single_slice_handles_preemption()
- test_multislice_simple_run()
- test_variable_multislice_run()
- test_multislice_run_twice()
- test_multislice_fail_once()
- test_multislice_one_slice_fails()
Note: These tests require real TPUs and are skipped in CI.
Success Criteria
- All ray_tpu orchestration code moved to fray.cluster.ray.tpu
- All tests pass at lib/fray/tests/ray/test_tpu.py
- No Docker code migrated (stays in Levanter)
- Clean module organization under Ray backend
- Documentation updated with new import paths