135 changes: 124 additions & 11 deletions exo/orchestration/node.py
@@ -54,6 +54,11 @@ def __init__(
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
self.topology_inference_engines_pool: List[List[str]] = []
self.outstanding_requests = {}
# Batched sampling queues: key -> list[(request_id, shard_dict, logits)]
self._sample_queues: Dict[tuple, List[tuple[str, dict, np.ndarray]]] = {}
self._sample_flush_tasks: Dict[tuple, asyncio.Task] = {}
self._sample_max_batch_size: int = int(os.getenv("EXO_SAMPLE_MAX_BATCH", "8"))
self._sample_batch_timeout_ms: int = int(os.getenv("EXO_SAMPLE_BATCH_TIMEOUT_MS", "5"))
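# Tuning note (an assumption, not a measured recommendation): raising EXO_SAMPLE_MAX_BATCH
# amortizes more sampling work per flush, while EXO_SAMPLE_BATCH_TIMEOUT_MS caps the extra
# latency a lone request spends waiting for batchmates, e.g.
# EXO_SAMPLE_MAX_BATCH=16 EXO_SAMPLE_BATCH_TIMEOUT_MS=10 (illustrative values, not defaults).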

async def start(self, wait_for_peers: int = 0) -> None:
self.device_capabilities = await device_capabilities()
@@ -125,21 +130,19 @@ async def process_inference_result(
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
forward = token.reshape(1, -1)
intermediate_result = [self.buffered_token_output[request_id][0][-1]]
# Enqueue for batched sampling instead of sampling immediately
await self._enqueue_sample(shard, request_id, result)
# Defer forwarding until batch sampling flush; use placeholders here
forward = None
intermediate_result = []
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
if shard.is_last_layer() and shard.model_id == 'stable-diffusion-2-1-base':
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))

@@ -148,10 +151,12 @@ async def process_inference_result(
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
# If we deferred due to batching, forwarding will be handled in the batch flush
if forward is not None:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))

return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
return np.array(self.buffered_token_output.get(request_id, ([], False))[0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result


async def process_prompt(
Expand Down Expand Up @@ -626,3 +631,111 @@ def handle_stable_diffusion(self, inference_state, result):
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state

def _sample_key(self, shard, logits: np.ndarray) -> tuple:
# Key by model + shard layout + logits shape to ensure stackable batches
return (
shard.model_id,
shard.start_layer,
shard.end_layer,
shard.n_layers,
tuple(logits.shape[1:]), # ignore batch dim if present later
str(logits.dtype),
)
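# Example key, with hypothetical values: ('llama-3.2-1b', 0, 15, 16, (128256,), 'float32')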

async def _enqueue_sample(self, shard, request_id: str, logits: np.ndarray) -> None:
key = self._sample_key(shard, logits)
if key not in self._sample_queues:
self._sample_queues[key] = []
# Store a lightweight shard dict to avoid capturing a mutable Shard object
shard_dict = shard.to_dict() if hasattr(shard, 'to_dict') else {
'model_id': shard.model_id, 'start_layer': shard.start_layer, 'end_layer': shard.end_layer, 'n_layers': shard.n_layers
}
self._sample_queues[key].append((request_id, shard_dict, logits))

# Flush immediately if we reached max batch size
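# (awaiting the flush here, rather than spawning a task, applies backpressure:
# a caller cannot enqueue faster than full batches can drain)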
if len(self._sample_queues[key]) >= self._sample_max_batch_size:
if key in self._sample_flush_tasks:
self._sample_flush_tasks[key].cancel()
self._sample_flush_tasks.pop(key, None)
await self._flush_sample_batch(key)
return

# Otherwise schedule a short timeout flush if not already scheduled
if key not in self._sample_flush_tasks:
async def _delayed_flush(k):
try:
await asyncio.sleep(self._sample_batch_timeout_ms / 1000.0)
await self._flush_sample_batch(k)
except asyncio.CancelledError:
pass
self._sample_flush_tasks[key] = asyncio.create_task(_delayed_flush(key))
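# At most one timer runs per key, so concurrent enqueues coalesce into the same
# pending flush instead of each scheduling its own.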

async def _flush_sample_batch(self, key: tuple) -> None:
batch = self._sample_queues.get(key, [])
if not batch:
return
# Clear the queue and drop any pending flush task; never cancel the currently
# running task, or the cancellation would abort this flush at its next await point
self._sample_queues[key] = []
pending = self._sample_flush_tasks.pop(key, None)
if pending is not None and pending is not asyncio.current_task():
pending.cancel()

# Unpack
request_ids, shard_dicts, logits_list = zip(*batch)
try:
engine_name = self.inference_engine.__class__.__name__
# Tinygrad sampling currently expects unbatched logits; fall back when batching >1
if engine_name.startswith('Tinygrad') and len(logits_list) > 1:
raise RuntimeError('tinygrad-sample-no-batch')
# Stack logits on batch dimension
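# (each entry shares the trailing shape recorded in the key, so np.stack is safe;
# e.g. per-request logits of shape (1, vocab) become (batch, 1, vocab), and sample()
# is assumed to return one token per batch row; the size check below verifies this)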
stacked = np.stack(logits_list, axis=0)
tokens = await self.inference_engine.sample(stacked, temp=self.default_sample_temperature)
tokens = np.asarray(tokens).reshape(-1)
if tokens.size != len(logits_list):
# Safety check: a count mismatch would make zip() below silently drop requests
raise RuntimeError('sample-size-mismatch')
except Exception as e:
# Fallback: sample individually on error
if DEBUG >= 1: print(f"Batched sampling failed, falling back to per-request sampling: {e}")
tokens = []
for logits in logits_list:
tok = await self.inference_engine.sample(logits, temp=self.default_sample_temperature)
tokens.append(int(tok.item()))
tokens = np.asarray(tokens)  # match the ndarray type of the batched path for .tolist() below

# Process each sampled token
for req_id, shard_info, tok in zip(request_ids, shard_dicts, tokens.tolist()):
# Reconstruct shard
shard = Shard(
model_id=shard_info['model_id'],
start_layer=shard_info['start_layer'],
end_layer=shard_info['end_layer'],
n_layers=shard_info['n_layers'],
)
await self.inference_engine.ensure_shard(shard)
if req_id not in self.buffered_token_output:
self.buffered_token_output[req_id] = ([], False)
self.buffered_token_output[req_id][0].append(int(tok))
eos = getattr(self.inference_engine.tokenizer, 'eos_token_id', None)
is_finished = int(tok) == eos or len(self.buffered_token_output[req_id][0]) >= self.max_generate_tokens
if DEBUG >= 2:
print(f"[{req_id}] batched sample => {tok}, finished={is_finished}, buffered={len(self.buffered_token_output[req_id][0])}")

# Emit token to callbacks and peers
self.trigger_on_token_callbacks(req_id, [int(tok)], is_finished)
asyncio.create_task(self.broadcast_result(req_id, [int(tok)], is_finished))
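# Note: tokens for every request in a batch are emitted together at flush time,
# so per-request streaming latency is bounded by the batch timeout above.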

# Continue the generation loop by forwarding the sampled token as the next input
if not is_finished:
forward = np.array([[int(tok)]], dtype=np.int64)
self.outstanding_requests[req_id] = "waiting"
asyncio.create_task(self.forward_tensor(
shard,
forward,
req_id,
self.get_partition_index(offset=1),
None
))
else:
self.buffered_token_output[req_id] = (self.buffered_token_output[req_id][0], True)
self.outstanding_requests.pop(req_id, None)