29 changes: 22 additions & 7 deletions src/coffea/ml_tools/triton_wrapper.py
@@ -1,4 +1,5 @@
# For python niceties
import time
import warnings
from typing import Optional

@@ -56,6 +57,8 @@ class triton_wrapper(nonserializable_attribute, numpy_call_wrapper):

batch_size_fallback = 10  # Fallback in case the batch size cannot be determined.
http_client_concurrency = 12 # TODO: check whether this value is optimum
max_retry_attempts = 5  # Maximum number of retries before the exception is re-raised
retry_jitter_base_ms = 100  # Base delay (in milliseconds) for exponential backoff with jitter

def __init__(
self, model_url: str, client_args: Optional[dict] = None, batch_size=-1
@@ -310,13 +313,8 @@ def _get_infer_shape(name):
).astype(mtype)
)

# Making request to server
request = self.client.infer(
self.model,
model_version=self.version,
inputs=infer_inputs,
outputs=infer_outputs,
)
# Running the request with retry fallback
request = self.run_infer(infer_inputs, infer_outputs)
if output is None:
output = {
o: request.as_numpy(o)[start_idx:stop_idx] for o in output_list
@@ -336,3 +334,20 @@ def _get_infer_shape(name):
}

return {k: v[:orig_len] for k, v in output.items()}

def run_infer(self, inputs, outputs, attempt=0):
    """Thin wrapper around tritonclient's infer that automatically retries with backoff+jitter on inference server failures"""
    try:
        return self.client.infer(
            self.model,
            model_version=self.version,
            inputs=inputs,
            outputs=outputs,
        )
    except tritonclient.utils.InferenceServerException:
        if attempt >= self.max_retry_attempts:
            raise
        # Retry with exponential backoff + full jitter (base delay is in ms):
        # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
        time.sleep(
            numpy.random.rand() * self.retry_jitter_base_ms * (2**attempt) / 1000
        )
        return self.run_infer(inputs, outputs, attempt + 1)
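
For reference, here is a minimal standalone sketch of the full-jitter backoff strategy the new run_infer method applies. It is illustrative only, not coffea code: the constants mirror max_retry_attempts and retry_jitter_base_ms from the diff, and flaky_call is a hypothetical stand-in for the Triton inference request.

import random
import time

MAX_RETRY_ATTEMPTS = 5      # mirrors triton_wrapper.max_retry_attempts
RETRY_JITTER_BASE_MS = 100  # mirrors triton_wrapper.retry_jitter_base_ms

def call_with_full_jitter_retry(flaky_call, attempt=0):
    """Retry flaky_call with exponential backoff and full jitter.

    Before retry k the caller sleeps a uniform random time in
    [0, RETRY_JITTER_BASE_MS * 2**k) milliseconds, so the worst-case wait
    roughly doubles per attempt while concurrent clients stay decorrelated.
    """
    try:
        return flaky_call()
    except RuntimeError:
        if attempt >= MAX_RETRY_ATTEMPTS:
            raise
        delay_ms = random.random() * RETRY_JITTER_BASE_MS * (2**attempt)
        time.sleep(delay_ms / 1000)  # time.sleep expects seconds
        return call_with_full_jitter_retry(flaky_call, attempt + 1)

With these defaults the worst-case sleeps are roughly 0.1 s, 0.2 s, 0.4 s, 0.8 s and 1.6 s across the five retries before the exception is finally re-raised.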