diff --git a/src/coffea/ml_tools/triton_wrapper.py b/src/coffea/ml_tools/triton_wrapper.py
index 6e8769b71..6b4c45be2 100644
--- a/src/coffea/ml_tools/triton_wrapper.py
+++ b/src/coffea/ml_tools/triton_wrapper.py
@@ -1,4 +1,5 @@
 # For python niceties
+import time
 import warnings
 from typing import Optional
 
@@ -56,6 +57,8 @@ class triton_wrapper(nonserializable_attribute, numpy_call_wrapper):
     batch_size_fallback = 10  # Fallback should the batch size not be determined.
     http_client_concurrency = 12  # TODO: check whether this value is optimum
+    max_retry_attempts = 5  # Maximum number of retries for failed inference requests
+    retry_jitter_base_ms = 100  # Base backoff jitter in milliseconds
 
     def __init__(
         self, model_url: str, client_args: Optional[dict] = None, batch_size=-1
     ):
@@ -310,13 +313,8 @@ def _get_infer_shape(name):
                     ).astype(mtype)
                 )
 
-            # Making request to server
-            request = self.client.infer(
-                self.model,
-                model_version=self.version,
-                inputs=infer_inputs,
-                outputs=infer_outputs,
-            )
+            # Run the request with automatic retry on server failures
+            request = self.run_infer(infer_inputs, infer_outputs)
             if output is None:
                 output = {
                     o: request.as_numpy(o)[start_idx:stop_idx] for o in output_list
@@ -336,3 +334,21 @@ def _get_infer_shape(name):
             }
 
         return {k: v[:orig_len] for k, v in output.items()}
+
+    def run_infer(self, inputs, outputs, attempt=0):
+        """Thin wrapper around ``client.infer`` that automatically retries with
+        exponential backoff and full jitter on inference server failures."""
+        try:
+            return self.client.infer(
+                self.model, model_version=self.version, inputs=inputs, outputs=outputs
+            )
+        except tritonclient.utils.InferenceServerException as err:
+            if attempt >= self.max_retry_attempts:
+                raise err
+            # Retry with exponential backoff + full jitter:
+            # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
+            # retry_jitter_base_ms is in milliseconds; time.sleep expects seconds.
+            time.sleep(
+                numpy.random.rand() * self.retry_jitter_base_ms * (2**attempt) / 1000
+            )
+            return self.run_infer(inputs, outputs, attempt + 1)
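
For context, a minimal standalone sketch of the exponential-backoff-with-full-jitter pattern that `run_infer` applies (per the AWS article linked in the patch). The helper name `with_retries`, the `flaky_call` argument, and the use of `RuntimeError` as a stand-in for Triton's `InferenceServerException` are illustrative assumptions, not part of the patch:

```python
# Illustrative sketch of retry with exponential backoff + full jitter.
# Names here are hypothetical and chosen to mirror the patched method.
import random
import time

MAX_RETRY_ATTEMPTS = 5
JITTER_BASE_MS = 100


def with_retries(flaky_call, attempt=0):
    """Invoke `flaky_call`, retrying with full-jitter backoff on failure."""
    try:
        return flaky_call()
    except RuntimeError:
        if attempt >= MAX_RETRY_ATTEMPTS:
            raise
        # Sleep a uniformly random duration in [0, base * 2**attempt) ms,
        # converted to seconds for time.sleep.
        time.sleep(random.random() * JITTER_BASE_MS * (2**attempt) / 1000)
        return with_retries(flaky_call, attempt + 1)
```

Full jitter draws each sleep uniformly from `[0, base * 2**attempt)`, so retries from many concurrent clients spread out over time rather than all hitting the server again at the same instant.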