diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 1281aa8ae..5e58cc538 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -54,6 +54,11 @@ def __init__(
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
     self.outstanding_requests = {}
+    # Batched sampling queues: key -> list[(request_id, shard_dict, logits)]
+    self._sample_queues: Dict[tuple, List[tuple[str, dict, np.ndarray]]] = {}
+    self._sample_flush_tasks: Dict[tuple, asyncio.Task] = {}
+    self._sample_max_batch_size: int = int(os.getenv("EXO_SAMPLE_MAX_BATCH", "8"))
+    self._sample_batch_timeout_ms: int = int(os.getenv("EXO_SAMPLE_BATCH_TIMEOUT_MS", "5"))
 
   async def start(self, wait_for_peers: int = 0) -> None:
     self.device_capabilities = await device_capabilities()
@@ -125,13 +130,11 @@ async def process_inference_result(
         self.buffered_token_output[request_id] = ([], False)
       is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if shard.is_last_layer() and not is_finished:
-        token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-        await self.inference_engine.ensure_shard(shard)
-        self.buffered_token_output[request_id][0].append(token.item())
-        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-        forward = token.reshape(1, -1)
-        intermediate_result = [self.buffered_token_output[request_id][0][-1]]
+        # Enqueue for batched sampling instead of sampling immediately
+        await self._enqueue_sample(shard, request_id, result)
+        # Defer forwarding until batch sampling flush; use placeholders here
+        forward = None
+        intermediate_result = []
       else:
         forward = result
     else:
@@ -139,7 +142,7 @@
       is_finished = inference_state.get("is_finished", False)
       intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
       forward = result
-    if shard.is_last_layer():
+    if shard.is_last_layer() and shard.model_id == 'stable-diffusion-2-1-base':
       self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
       asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
@@ -148,10 +151,12 @@
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
       self.outstanding_requests.pop(request_id)
     else:
-      self.outstanding_requests[request_id] = "waiting"
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
+      # If we deferred due to batching, forwarding will be handled in the batch flush
+      if forward is not None:
+        self.outstanding_requests[request_id] = "waiting"
+        asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
 
-    return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
+    return np.array(self.buffered_token_output.get(request_id, ([], False))[0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
 
 
   async def process_prompt(
@@ -626,3 +631,111 @@ def handle_stable_diffusion(self, inference_state, result):
     if progress[0] == progress[1]:
       intermediate_result = result
     return intermediate_result, inference_state
+
+  def _sample_key(self, shard, logits: np.ndarray) -> tuple:
+    # Key by model + shard layout + logits shape to ensure stackable batches
+    return (
+      shard.model_id,
+      shard.start_layer,
+      shard.end_layer,
+      shard.n_layers,
+      tuple(logits.shape[1:]),  # ignore batch dim if present later
+      str(logits.dtype),
+    )
+
+  async def _enqueue_sample(self, shard, request_id: str, logits: np.ndarray) -> None:
+    key = self._sample_key(shard, logits)
+    if key not in self._sample_queues:
+      self._sample_queues[key] = []
+    # Store a lightweight shard dict to avoid capturing a mutable object
+    shard_dict = shard.to_dict() if hasattr(shard, 'to_dict') else {
+      'model_id': shard.model_id, 'start_layer': shard.start_layer, 'end_layer': shard.end_layer, 'n_layers': shard.n_layers
+    }
+    self._sample_queues[key].append((request_id, shard_dict, logits))
+
+    # Flush immediately if we reached max batch size
+    if len(self._sample_queues[key]) >= self._sample_max_batch_size:
+      if key in self._sample_flush_tasks:
+        self._sample_flush_tasks[key].cancel()
+        self._sample_flush_tasks.pop(key, None)
+      await self._flush_sample_batch(key)
+      return
+
+    # Otherwise schedule a short timeout flush if not already scheduled
+    if key not in self._sample_flush_tasks:
+      async def _delayed_flush(k):
+        try:
+          await asyncio.sleep(self._sample_batch_timeout_ms / 1000.0)
+          await self._flush_sample_batch(k)
+        except asyncio.CancelledError:
+          pass
+      self._sample_flush_tasks[key] = asyncio.create_task(_delayed_flush(key))
+
+  async def _flush_sample_batch(self, key: tuple) -> None:
+    batch = self._sample_queues.get(key, [])
+    if not batch:
+      return
+    # Clear the queue; drop any pending timeout flush without cancelling ourselves
+    self._sample_queues[key] = []
+    pending = self._sample_flush_tasks.pop(key, None)
+    if pending is not None and pending is not asyncio.current_task():
+      pending.cancel()
+
+    # Unpack
+    request_ids, shard_dicts, logits_list = zip(*batch)
+    try:
+      engine_name = self.inference_engine.__class__.__name__
+      # Tinygrad sampling currently expects unbatched logits; fall back when batching >1
+      if engine_name.startswith('Tinygrad') and len(logits_list) > 1:
+        raise RuntimeError('tinygrad-sample-no-batch')
+      # Stack logits on a new batch dimension
+      stacked = np.stack(logits_list, axis=0)
+      tokens = await self.inference_engine.sample(stacked, temp=self.default_sample_temperature)
+      tokens = np.asarray(tokens).reshape(-1)
+      if tokens.size != len(logits_list):
+        # Safety check: unexpected number of sampled tokens, fall back
+        raise RuntimeError('sample-size-mismatch')
+    except Exception as e:
+      # Fallback: sample individually on error
+      if DEBUG >= 1: print(f"Batched sampling failed, falling back to per-request sampling: {e}")
+      tokens = []
+      for logits in logits_list:
+        tok = await self.inference_engine.sample(logits, temp=self.default_sample_temperature)
+        tokens.append(int(tok.item()))
+
+    # Process each sampled token (tokens may be an ndarray or a plain list after the fallback)
+    for req_id, shard_info, tok in zip(request_ids, shard_dicts, np.asarray(tokens).tolist()):
+      # Reconstruct shard
+      shard = Shard(
+        model_id=shard_info['model_id'],
+        start_layer=shard_info['start_layer'],
+        end_layer=shard_info['end_layer'],
+        n_layers=shard_info['n_layers'],
+      )
+      await self.inference_engine.ensure_shard(shard)
+      if req_id not in self.buffered_token_output:
+        self.buffered_token_output[req_id] = ([], False)
+      self.buffered_token_output[req_id][0].append(int(tok))
+      eos = getattr(self.inference_engine.tokenizer, 'eos_token_id', None)
+      is_finished = int(tok) == eos or len(self.buffered_token_output[req_id][0]) >= self.max_generate_tokens
+      if DEBUG >= 2:
+        print(f"[{req_id}] batched sample => {tok}, finished={is_finished}, buffered={len(self.buffered_token_output[req_id][0])}")
+
+      # Emit token to callbacks and peers
+      self.trigger_on_token_callbacks(req_id, [int(tok)], is_finished)
+      asyncio.create_task(self.broadcast_result(req_id, [int(tok)], is_finished))
+
+      # Continue the loop by forwarding the sampled token as next input
+      if not is_finished:
+        forward = np.array([[int(tok)]], dtype=np.int64)
+        self.outstanding_requests[req_id] = "waiting"
+        asyncio.create_task(self.forward_tensor(
+          shard,
+          forward,
+          req_id,
+          self.get_partition_index(offset=1),
+          None
+        ))
+      else:
+        self.buffered_token_output[req_id] = (self.buffered_token_output[req_id][0], True)
+        self.outstanding_requests.pop(req_id, None)
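
Below is a minimal, standalone sketch of the size-or-timeout micro-batching pattern that `_enqueue_sample`/`_flush_sample_batch` implement, with the node and inference-engine specifics stripped out. `MicroBatcher` and `handle_batch` are hypothetical names used only for illustration; the defaults mirror `EXO_SAMPLE_MAX_BATCH` (8) and `EXO_SAMPLE_BATCH_TIMEOUT_MS` (5 ms).

```python
# Hypothetical, self-contained illustration of the batching scheme in this diff:
# items are queued per key and flushed either when the batch is full or after a
# short timeout, whichever comes first. Not exo code.
import asyncio
from typing import Any, Callable, Dict, List


class MicroBatcher:
  def __init__(self, handle_batch: Callable, max_batch_size: int = 8, timeout_ms: int = 5):
    self.handle_batch = handle_batch        # coroutine invoked with (key, list of items)
    self.max_batch_size = max_batch_size    # flush as soon as this many items are queued
    self.timeout_ms = timeout_ms            # otherwise flush after this delay
    self.queues: Dict[Any, List[Any]] = {}
    self.flush_tasks: Dict[Any, asyncio.Task] = {}

  async def enqueue(self, key: Any, item: Any) -> None:
    self.queues.setdefault(key, []).append(item)
    if len(self.queues[key]) >= self.max_batch_size:
      await self._flush(key)                # full batch: flush immediately
    elif key not in self.flush_tasks:
      self.flush_tasks[key] = asyncio.create_task(self._delayed_flush(key))

  async def _delayed_flush(self, key: Any) -> None:
    try:
      await asyncio.sleep(self.timeout_ms / 1000.0)
      await self._flush(key)
    except asyncio.CancelledError:
      pass

  async def _flush(self, key: Any) -> None:
    batch = self.queues.get(key, [])
    if not batch:
      return
    self.queues[key] = []
    pending = self.flush_tasks.pop(key, None)
    if pending is not None and pending is not asyncio.current_task():
      pending.cancel()                      # drop the pending timeout flush, if any
    await self.handle_batch(key, batch)


async def main():
  async def handle_batch(key, batch):
    print(f"flushed {len(batch)} item(s) for {key}: {batch}")

  batcher = MicroBatcher(handle_batch, max_batch_size=3, timeout_ms=5)
  for i in range(3):                        # third item triggers an immediate size-based flush
    await batcher.enqueue("model-a", i)
  await batcher.enqueue("model-b", 99)      # lone item is flushed by the timeout instead
  await asyncio.sleep(0.05)


if __name__ == "__main__":
  asyncio.run(main())
```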
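Running the sketch prints one size-triggered flush for `model-a` and one timeout-triggered flush for `model-b`. The node-level version adds two details on top of this pattern: batches are keyed by model id, shard layout, and logits shape/dtype so that only stackable logits are grouped, and `_flush_sample_batch` falls back to per-request sampling whenever batched sampling fails (for example on the Tinygrad path, which currently expects unbatched logits).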