Commit 92c178d

Fix ORT pipeline (#2274)

* fix pipeline
* add test
* add test
* add test
* only test for targeted architecture

1 parent eb6d9ed commit 92c178d

File tree: 2 files changed (+47, -0)

* optimum/pipelines/pipelines_base.py
* tests/onnxruntime/test_modeling.py

optimum/pipelines/pipelines_base.py (1 addition, 0 deletions)

@@ -244,6 +244,7 @@ def load_ort_pipeline(
         model_id = SUPPORTED_TASKS[targeted_task]["default"]
         model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id, export=True)
     elif isinstance(model, str):
+        model_id = model
         model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(
             model, revision=revision, subfolder=subfolder, token=token, **model_kwargs
         )
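
What the fix does: when `model` arrives as a plain string (a Hub id or a local path), this branch previously loaded the model without ever setting `model_id`, so anything resolved from `model_id` afterwards (the tokenizer, presumably, among other components) could not be found. Recording `model_id = model` before the string is overwritten by the loaded model closes that gap. A minimal sketch of the repaired call path, reusing the Hub checkpoint referenced by the new tests in this commit:

from optimum.pipelines import pipeline

# Passing `model` as a string now records it as `model_id` before the
# string is replaced by the loaded ORT model instance.
pipe = pipeline(
    "text-generation",
    model="optimum-internal-testing/tiny-random-llama",
    revision="onnx",
    accelerator="ort",
)
print(pipe("this is an example input")[0]["generated_text"])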

tests/onnxruntime/test_modeling.py (46 additions, 0 deletions)
@@ -2762,8 +2762,32 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo
         self.assertIsInstance(outputs[0]["generated_text"], str)
         self.assertTrue(len(outputs[0]["generated_text"]) > len(text))
 
+        if model_arch == "llama":
+            with tempfile.TemporaryDirectory() as tmpdir:
+                pipe.save_pretrained(tmpdir)
+                model_kwargs = {"use_cache": use_cache, "use_io_binding": use_io_binding}
+                pipe = pipeline(
+                    "text-generation",
+                    model=tmpdir,
+                    model_kwargs=model_kwargs,
+                    accelerator="ort",
+                )
+                outputs_local_model = pipe(text)
+                self.assertEqual(outputs[0]["generated_text"], outputs_local_model[0]["generated_text"])
+
         gc.collect()
 
+    def test_load_pipeline(self):
+        pipe = pipeline(
+            "text-generation",
+            model="optimum-internal-testing/tiny-random-llama",
+            revision="onnx",
+            accelerator="ort",
+        )
+
+        outputs = pipe("this is an example input")
+        self.assertIsInstance(outputs[0]["generated_text"], str)
+
     @pytest.mark.run_in_series
     def test_pipeline_model_is_none(self):
         pipe = pipeline("text-generation")
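
The llama round trip above is the regression case for the one-line fix: `save_pretrained` followed by `pipeline(model=tmpdir, ...)` passes a plain directory path as a string, which before this commit left `model_id` unset. A standalone sketch of the same round trip, with identifiers taken from the tests in this commit:

import tempfile

from optimum.pipelines import pipeline

pipe = pipeline(
    "text-generation",
    model="optimum-internal-testing/tiny-random-llama",
    revision="onnx",
    accelerator="ort",
)

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir)
    # Reloading from a local directory exercises the fixed
    # `isinstance(model, str)` branch in load_ort_pipeline.
    pipe = pipeline("text-generation", model=tmpdir, accelerator="ort")
    print(pipe("this is an example input")[0]["generated_text"])
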
@@ -4152,8 +4176,30 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac
         self.assertEqual(pipe.device, onnx_model.device)
         self.assertIsInstance(outputs[0]["translation_text"], str)
 
+        if model_arch == "t5":
+            with tempfile.TemporaryDirectory() as tmpdir:
+                pipe.save_pretrained(tmpdir)
+                model_kwargs = {"use_cache": use_cache}
+                pipe = pipeline(
+                    "translation_en_to_de",
+                    model=tmpdir,
+                    model_kwargs=model_kwargs,
+                    accelerator="ort",
+                )
+                outputs_local_model = pipe(text)
+                self.assertEqual(outputs[0]["translation_text"], outputs_local_model[0]["translation_text"])
+
         gc.collect()
 
+    def test_load_pipeline(self):
+        pipe = pipeline(
+            "text2text-generation",
+            model="echarlaix/t5-small-onnx",
+            accelerator="ort",
+        )
+        outputs = pipe("this is an example input")
+        self.assertIsInstance(outputs[0]["generated_text"], str)
+
     @pytest.mark.run_in_series
     def test_pipeline_model_is_none(self):
         # Text2text generation
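
The seq2seq hunks mirror the decoder-only ones: the t5 round trip reloads a translation pipeline from a local directory, and the new `test_load_pipeline` loads `echarlaix/t5-small-onnx` straight from the Hub, so both string inputs the fixed branch can receive (local path and Hub id) are covered for ORT seq2seq models as well.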
