
Commit ce969f6

Cleanups.
1 parent 8e731c1 commit ce969f6

File tree: 15 files changed, +215 -100 lines changed

data_browser/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ requires-python = ">=3.11"
 dependencies = [
     "py7zr>=0.22.0",
     "fsspec>=2024.9.0",
-    "zstandard>=0.23.0",
+    "zstandard>=0.18.0",
     "flask>=3.1.1",
     "pyarrow>=17.0.0",
     "gcsfs>=2024.9.0.post1",

experiments/evals/evals.py

Lines changed: 16 additions & 18 deletions
@@ -18,20 +18,6 @@

 import logging

-from experiments.evals.engine_configs import DEFAULT_LM_EVAL_MODEL_KWARGS
-from experiments.evals.resource_configs import SINGLE_TPU_V4_8, SINGLE_TPU_V6E_8
-from experiments.evals.task_configs import (
-    BASE_GENERATION_TASKS,
-    CORE_TASKS,
-    CORE_TASKS_PLUS_LEADERBOARD,
-    KEY_GENERATION_TASKS,
-    KEY_MULTIPLE_CHOICE_TASKS,
-    MMLU_0_SHOT,
-    MMLU_5_SHOT,
-    MMLU_PRO_5_SHOT,
-    OPEN_LM_LEADERBOARD_GEN,
-    OPEN_LM_LEADERBOARD_MCQ,
-)
 from fray.cluster.base import ResourceConfig
 from marin.evaluation.evaluation_config import (
     EvalTaskConfig,

@@ -50,6 +36,21 @@
     versioned,
 )

+from experiments.evals.engine_configs import DEFAULT_LM_EVAL_MODEL_KWARGS
+from experiments.evals.resource_configs import SINGLE_TPU_V4_8, SINGLE_TPU_V6E_8
+from experiments.evals.task_configs import (
+    BASE_GENERATION_TASKS,
+    CORE_TASKS,
+    CORE_TASKS_PLUS_LEADERBOARD,
+    KEY_GENERATION_TASKS,
+    KEY_MULTIPLE_CHOICE_TASKS,
+    MMLU_0_SHOT,
+    MMLU_5_SHOT,
+    MMLU_PRO_5_SHOT,
+    OPEN_LM_LEADERBOARD_GEN,
+    OPEN_LM_LEADERBOARD_MCQ,
+)
+
 logger = logging.getLogger(__name__)

@@ -71,10 +72,8 @@ def evaluate_helm(
         max_eval_instances: Maximum number of evaluation instances to run
         engine_kwargs: Additional keyword arguments to pass to the vLLM engine
     """
-    # Auto-detect device from resource config
     device = infer_device_from_resource_config(resource_config)

-    # Build ModelConfig
     model_config = ModelConfig(
         name=model_name,
         path=model_path,

@@ -83,7 +82,6 @@
         apply_chat_template=False,
     )

-    # Build InferencePoolConfig
     pool_config = InferencePoolConfig(
         resource_config=resource_config,
         model_config=model_config,

@@ -101,7 +99,7 @@ def evaluate_helm(
             evals=evals,
             max_eval_instances=max_eval_instances,
         ),
-        pip_dependency_groups=["eval"],
+        pip_dependency_groups=["eval", "pip:crfm-helm@git+https://github.com/stanford-crfm/helm.git"],
     )

experiments/evals/test_helm_migration.py

Lines changed: 14 additions & 10 deletions
@@ -17,27 +17,31 @@

 import os

-from fray.cluster.base import ResourceConfig
+from fray.cluster.base import ResourceConfig, TpuConfig
+from marin.execution.executor import executor_main
+
 from experiments.evals.evals import evaluate_helm
 from experiments.evals.task_configs import EvalTaskConfig
-from marin.execution.executor import executor_main

-# Set local output prefix if not set
+# Set output prefix if not set
 if "MARIN_PREFIX" not in os.environ:
-    os.environ["MARIN_PREFIX"] = "/tmp/marin-helm-test"
+    os.environ["MARIN_PREFIX"] = "gs://marin-eu-west4/evals/helm-migration-test"

-# Local test resource config
-local_config = ResourceConfig(
-    cpu=2,
-    ram="8g",
+# TPU test resource config
+tpu_config = ResourceConfig(
+    cpu=16,
+    ram="64g",
+    disk="10g",
+    device=TpuConfig(type="v5litepod-4", count=4),
     replicas=1,
+    regions=["eu-west4"],
 )

 step = evaluate_helm(
-    model_name="test-baby-llama",
+    model_name="timinar/baby-llama-58m",
     model_path="timinar/baby-llama-58m",
     evals=[EvalTaskConfig(name="mmlu", num_fewshot=0)],
-    resource_config=local_config,
+    resource_config=tpu_config,
     max_eval_instances=10,
 )

lib/fray/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ dependencies = [
     "mergedeep",
     "pyyaml>=6.0",
     "typing-extensions>=4.0",
-    "zstandard>=0.22.0",
+    "zstandard>=0.18.0",
 ]

 [project.scripts]

lib/fray/src/fray/cluster/__init__.py

Lines changed: 30 additions & 11 deletions
@@ -75,29 +75,48 @@ def set_current_cluster(cluster: Cluster) -> None:
 def current_cluster() -> Cluster:
     """Get the current cluster from context.

-    If no cluster is set in context but FRAY_CLUSTER_SPEC environment variable is present,
-    automatically creates and caches the cluster for this process.
+    Auto-detection priority:
+    1. Context variable (set via set_current_cluster())
+    2. Ray cluster (if ray.is_initialized())
+    3. FRAY_CLUSTER_SPEC environment variable
+    4. LocalCluster (default fallback)

     Returns:
-        The cluster instance set via set_current_cluster() or auto-created from env var
+        The cluster instance

     Raises:
-        RuntimeError: If no cluster has been set and FRAY_CLUSTER_SPEC is not present
+        RuntimeError: If cluster creation fails
     """
     cluster = _cluster_context.get()
     if cluster is not None:
         return cluster

+    # Auto-detect Ray execution
+    try:
+        import ray
+
+        if ray.is_initialized():
+            from fray.cluster.ray.cluster import RayCluster
+
+            cluster = RayCluster()
+            set_current_cluster(cluster)
+            logger.info("Auto-detected Ray cluster from ray.is_initialized()")
+            return cluster
+    except ImportError:
+        pass
+
+    # Check for FRAY_CLUSTER_SPEC
     cluster_spec = os.environ.get("FRAY_CLUSTER_SPEC")
-    if cluster_spec is None:
-        raise RuntimeError(
-            "No cluster set in current context. Either call set_current_cluster() "
-            "or set FRAY_CLUSTER_SPEC environment variable."
-        )
+    if cluster_spec is not None:
+        cluster = create_cluster(cluster_spec)
+        set_current_cluster(cluster)
+        logger.info(f"Auto-created cluster from FRAY_CLUSTER_SPEC={cluster_spec}")
+        return cluster

-    cluster = create_cluster(cluster_spec)
+    # Default to LocalCluster
+    cluster = LocalCluster()
     set_current_cluster(cluster)
-    logger.info(f"Auto-created cluster from FRAY_CLUSTER_SPEC={cluster_spec}")
+    logger.info("Using default LocalCluster")
     return cluster

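The new fallback order is easiest to see from the caller's side. A minimal sketch, assuming `current_cluster`, `set_current_cluster`, and `LocalCluster` are importable from `fray.cluster` (as the file location suggests); the Ray branch is only reached in a process where Ray has already been initialized:

# Sketch only: exercises the fallback order from the diff above.
import os

from fray.cluster import LocalCluster, current_cluster, set_current_cluster

# Priority 4: nothing in context, Ray not initialized, no FRAY_CLUSTER_SPEC
os.environ.pop("FRAY_CLUSTER_SPEC", None)
cluster = current_cluster()
print(type(cluster).__name__)  # LocalCluster, now cached in the context

# Priority 1: an explicitly set cluster always wins on later calls
explicit = LocalCluster()
set_current_cluster(explicit)
assert current_cluster() is explicit
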
lib/fray/src/fray/queue/http.py

Lines changed: 26 additions & 3 deletions
@@ -58,13 +58,13 @@ class HttpQueueServer:
     """HTTP server that manages multiple named queues.

     Example:
-        with HttpQueueServer(host="127.0.0.1", port=9999) as server:
+        with HttpQueueServer(host="0.0.0.0", port=9999) as server:
             queue_a = server.new_queue("tasks")
             queue_b = server.new_queue("results")
             queue_a.push("task1")
     """

-    def __init__(self, host: str = "127.0.0.1", port: int = 9999):
+    def __init__(self, host: str = "0.0.0.0", port: int = 9999):
         self.host = host
         self.port = port
         self.queues: dict[str, MemoryQueue] = {}

@@ -116,11 +116,34 @@ def release(queue_name: str, lease_id: str = Body(...), timestamp: float = Body(

         return app

+    def get_client_host(self) -> str:
+        """Get the hostname/IP that clients should use to connect.
+
+        When server binds to 0.0.0.0, clients need a specific hostname/IP.
+        Returns the actual IP address using default route.
+        """
+        if self.host == "0.0.0.0":
+            import socket
+
+            # Get the IP address that clients should use by checking default route
+            try:
+                s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+                s.connect(("8.8.8.8", 80))  # doesn't actually send anything
+                ip = s.getsockname()[0]
+                s.close()
+                return ip
+            except Exception:
+                # Fall back to localhost for local testing
+                return "127.0.0.1"
+        return self.host
+
     def new_queue(self, name: str) -> "HttpQueue":
         """Create or get a named queue, returns client."""
         if name not in self.queues:
             self.queues[name] = MemoryQueue()
-        return HttpQueue(host=self.host, port=self.port, queue_name=name)
+        # Use client-accessible host instead of bind host
+        client_host = self.get_client_host()
+        return HttpQueue(host=client_host, port=self.port, queue_name=name)

     def __enter__(self):
         self.server_thread = ServerThread(self.server, self.host, self.port)

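A short usage sketch of the new behaviour, with the import path assumed from the file location and the push call mirroring the class docstring; the point is that `new_queue()` now hands back a client wired to `get_client_host()` rather than the `0.0.0.0` bind address:

# Sketch only: clients get a routable address even when the server binds to 0.0.0.0.
from fray.queue.http import HttpQueueServer  # assumed import path

with HttpQueueServer(host="0.0.0.0", port=9999) as server:
    print(server.get_client_host())   # default-route IP, or 127.0.0.1 as a fallback
    tasks = server.new_queue("tasks")  # HttpQueue pointing at get_client_host(), not 0.0.0.0
    tasks.push("task1")
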
lib/marin/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ quality_dedup_consolidate = [
     "nltk>=3.8.1",
     "rbloom_gcs",
     "transformers",
-    "zstandard>=0.23.0",
+    "zstandard>=0.18.0",
 ]

 tokenize_train = [


lib/marin/src/marin/evaluation/backends/inference_pool.py

Lines changed: 57 additions & 8 deletions
@@ -202,16 +202,65 @@ def base_url(self) -> str:
         return f"http://{self.config.proxy_host}:{self.config.proxy_port}/v1"

     def wait_for_healthy(self, timeout: float = 300) -> None:
+        """Wait for the entire inference pool to be healthy.
+
+        Checks three components independently:
+        1. Fray job status (detect worker crashes during startup)
+        2. Proxy server health (FastAPI is responding)
+        3. VLLM worker via queue round-trip (worker can process requests)
+        """
         start_time = time.time()
+        proxy_healthy = False
+        worker_healthy = False
+        proxy_url = f"http://{self.config.proxy_host}:{self.config.proxy_port}"
+
         while True:
+            # Always check job status first - fail fast if worker crashed
             info = self.cluster.poll(self.job_id)
-            if info.status == "running":
-                logger.info("Pool job is running")
-                break
-            elif info.status in ["failed", "stopped"]:
-                raise RuntimeError(f"Pool job failed: {info.error_message}")
-
+            if info.status in ["failed", "stopped"]:
+                raise RuntimeError(f"Pool job failed during startup: {info.error_message}")
+
+            # Step 1: Check proxy server health independently
+            if not proxy_healthy:
+                try:
+                    response = requests.get(f"{proxy_url}/health", timeout=1)
+                    if response.status_code == 200:
+                        logger.info("Proxy server is healthy")
+                        proxy_healthy = True
+                except requests.RequestException:
+                    pass  # Proxy not ready yet
+
+            # Step 2: Check VLLM worker via queue round-trip
+            if proxy_healthy and not worker_healthy:
+                try:
+                    # Send a minimal test request through the queues to VLLM worker
+                    response = requests.post(
+                        f"{proxy_url}/v1/completions",
+                        json={
+                            "model": "default",
+                            "prompt": "test",
+                            "max_tokens": 1,
+                            "temperature": 0,
+                        },
+                        timeout=30,
+                    )
+                    if response.status_code == 200:
+                        logger.info("VLLM worker is healthy and responding via queues")
+                        worker_healthy = True
+                        return  # Success - all components healthy!
+                except requests.RequestException as e:
+                    logger.debug(f"VLLM worker health check failed: {e}")
+
+            # Check timeout
             if time.time() - start_time > timeout:
-                raise TimeoutError("Pool job failed to start within timeout")
+                issues = []
+                if not proxy_healthy:
+                    issues.append("proxy server not responding")
+                if not worker_healthy:
+                    issues.append("VLLM worker not responding")
+                issues.append(f"job status: {info.status}")
+
+                raise TimeoutError(f"Pool failed to become healthy within {timeout}s. Issues: {', '.join(issues)}")

-        logger.info("Pool is healthy")
+            # Wait before next check
+            time.sleep(2)

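The two HTTP probes that `wait_for_healthy` now runs can also be reproduced by hand against a pool's proxy, which helps when debugging a stuck startup. A rough sketch, with a placeholder host/port standing in for `config.proxy_host` / `config.proxy_port`:

# Sketch only: manual versions of the proxy and worker probes from wait_for_healthy().
import requests

proxy_url = "http://localhost:8000"  # placeholder for f"http://{proxy_host}:{proxy_port}"

# Proxy probe: is the FastAPI proxy answering at all?
print(requests.get(f"{proxy_url}/health", timeout=1).status_code)

# Worker probe: can the vLLM worker complete one token end-to-end through the queues?
resp = requests.post(
    f"{proxy_url}/v1/completions",
    json={"model": "default", "prompt": "test", "max_tokens": 1, "temperature": 0},
    timeout=30,
)
print(resp.status_code)
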
lib/marin/src/marin/evaluation/backends/vllm.py

Lines changed: 4 additions & 5 deletions
@@ -87,12 +87,10 @@ def start_vllm_server(
     )

     # Add device specification
-    if device != "auto":
-        command += f"--device {device} "
-
-    # Add distributed backend for TPU
+    # Note: vLLM 0.11.0 does not support --device flag
+    # For TPU, use distributed executor backend
     if device == "tpu":
-        command += "--distributed-executor-backend ray "
+        command += "--device tpu --distributed-executor-backend ray "

     # Add engine kwargs
     for key, value in engine_kwargs.items():

@@ -175,6 +173,7 @@ def vllm_server_worker(
     payload["model"] = model.name

     url = f"{server_url}{endpoint}"
+    logger.info(f"Sending request to vLLM at {url}")
     http_response = requests.post(
         url,
         json=payload,

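After this change the flag handling reduces to a single TPU branch. The helper below is hypothetical, written only to illustrate that branch, with the flag strings taken verbatim from the diff:

# Sketch only: device_flags() is illustrative, not a function in vllm.py.
def device_flags(device: str) -> str:
    # Only TPU gets explicit flags now; other devices rely on vLLM's own detection.
    if device == "tpu":
        return "--device tpu --distributed-executor-backend ray "
    return ""

print(device_flags("tpu"))   # --device tpu --distributed-executor-backend ray
print(device_flags("cuda"))  # (empty: no extra flags appended)
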
lib/marin/src/marin/evaluation/evaluators/helm_evaluator.py

Lines changed: 5 additions & 3 deletions
@@ -54,8 +54,10 @@ def write_model_config_files(model: ModelConfig, base_url: str, prod_env_path: P
     os.makedirs(prod_env_path, exist_ok=True)

     model_name: str = model.name
-    print(f"Loading tokenizer for model: {model_name}", flush=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    # Use model.path for loading from HuggingFace, fallback to model.name if path is None
+    model_path_or_name: str = model.path or model.name
+    print(f"Loading tokenizer for model: {model_path_or_name}", flush=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, trust_remote_code=True)
     print(f"Tokenizer loaded, max_length: {tokenizer.model_max_length}", flush=True)

     content: dict = {

@@ -101,7 +103,7 @@ def write_model_config_files(model: ModelConfig, base_url: str, prod_env_path: P
             "name": model_name,
             "tokenizer_spec": {
                 "class_name": "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer",
-                "args": {"pretrained_model_name_or_path": model_name, "trust_remote_code": True},
+                "args": {"pretrained_model_name_or_path": model_path_or_name, "trust_remote_code": True},
             },
             "prefix_token": tokenizer.bos_token,
             "end_of_text_token": tokenizer.eos_token,
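The path-or-name fallback is a one-liner, but its effect is clearest for a model whose display name is not a HuggingFace repo. The values below come from the earlier test script and are illustrative only:

# Sketch only: mirrors the model_path_or_name fallback added above.
model_name = "test-baby-llama"          # display name; not resolvable on the HuggingFace Hub
model_path = "timinar/baby-llama-58m"   # actual HF repo (model.path), may be None

model_path_or_name = model_path or model_name
print(model_path_or_name)  # timinar/baby-llama-58m -> what AutoTokenizer.from_pretrained() now receives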