@@ -2762,8 +2762,32 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo
27622762 self .assertIsInstance (outputs [0 ]["generated_text" ], str )
27632763 self .assertTrue (len (outputs [0 ]["generated_text" ]) > len (text ))
27642764
2765+ if model_arch == "llama" :
2766+ with tempfile .TemporaryDirectory () as tmpdir :
2767+ pipe .save_pretrained (tmpdir )
2768+ model_kwargs = {"use_cache" : use_cache , "use_io_binding" : use_io_binding }
2769+ pipe = pipeline (
2770+ "text-generation" ,
2771+ model = tmpdir ,
2772+ model_kwargs = model_kwargs ,
2773+ accelerator = "ort" ,
2774+ )
2775+ outputs_local_model = pipe (text )
2776+ self .assertEqual (outputs [0 ]["generated_text" ], outputs_local_model [0 ]["generated_text" ])
2777+
27652778 gc .collect ()
27662779
def test_load_pipeline(self):
    """Smoke test: build an ORT-accelerated text-generation pipeline directly
    from a Hub model id (ONNX weights on the ``onnx`` revision) and check that
    it produces a string continuation for a simple prompt.
    """
    prompt = "this is an example input"
    ort_pipe = pipeline(
        "text-generation",
        model="optimum-internal-testing/tiny-random-llama",
        revision="onnx",
        accelerator="ort",
    )
    generated = ort_pipe(prompt)
    # The pipeline returns a list of dicts; the generated text must be a str.
    self.assertIsInstance(generated[0]["generated_text"], str)
27672791 @pytest .mark .run_in_series
27682792 def test_pipeline_model_is_none (self ):
27692793 pipe = pipeline ("text-generation" )
@@ -4152,8 +4176,30 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac
41524176 self .assertEqual (pipe .device , onnx_model .device )
41534177 self .assertIsInstance (outputs [0 ]["translation_text" ], str )
41544178
4179+ if model_arch == "t5" :
4180+ with tempfile .TemporaryDirectory () as tmpdir :
4181+ pipe .save_pretrained (tmpdir )
4182+ model_kwargs = {"use_cache" : use_cache }
4183+ pipe = pipeline (
4184+ "translation_en_to_de" ,
4185+ model = tmpdir ,
4186+ model_kwargs = model_kwargs ,
4187+ accelerator = "ort" ,
4188+ )
4189+ outputs_local_model = pipe (text )
4190+ self .assertEqual (outputs [0 ]["translation_text" ], outputs_local_model [0 ]["translation_text" ])
4191+
41554192 gc .collect ()
41564193
def test_load_pipeline(self):
    """Smoke test: build an ORT-accelerated text2text-generation pipeline from
    a Hub model id that ships ONNX weights and check that it returns a string
    output for a simple prompt.
    """
    prompt = "this is an example input"
    ort_pipe = pipeline(
        "text2text-generation",
        model="echarlaix/t5-small-onnx",
        accelerator="ort",
    )
    generated = ort_pipe(prompt)
    # The pipeline returns a list of dicts; the generated text must be a str.
    self.assertIsInstance(generated[0]["generated_text"], str)
41574203 @pytest .mark .run_in_series
41584204 def test_pipeline_model_is_none (self ):
41594205 # Text2text generation
0 commit comments