1717from typing import Any , Dict
1818
1919import numpy as np
20+ from huggingface_hub .constants import HF_HUB_CACHE
2021from PIL import Image
2122from transformers import AutoTokenizer
2223from transformers .pipelines import Pipeline
2324
24- from optimum .pipelines import pipeline
25+ from optimum .pipelines import pipeline as optimum_pipeline
26+ from optimum .utils .testing_utils import remove_directory
27+
28+
29+ GENERATE_KWARGS = {"max_new_tokens" : 10 , "min_new_tokens" : 5 , "do_sample" : True }
2530
2631
2732class ORTPipelineTest (unittest .TestCase ):
@@ -33,20 +38,18 @@ def _create_dummy_text(self) -> str:
3338
3439 def _create_dummy_image (self ) -> Image .Image :
3540 """Create dummy image input for image-based tasks"""
36- # Create a simple RGB image
3741 np_image = np .random .randint (0 , 256 , (224 , 224 , 3 ), dtype = np .uint8 )
3842 return Image .fromarray (np_image )
3943
4044 def _create_dummy_audio (self ) -> Dict [str , Any ]:
4145 """Create dummy audio input for audio-based tasks"""
42- # Create a dummy audio array (16kHz sample rate, 1 second)
4346 sample_rate = 16000
4447 audio_array = np .random .randn (sample_rate ).astype (np .float32 )
4548 return {"array" : audio_array , "sampling_rate" : sample_rate }
4649
4750 def test_text_classification_pipeline (self ):
4851 """Test text classification ORT pipeline"""
49- pipe = pipeline (task = "text-classification" , accelerator = "ort" )
52+ pipe = optimum_pipeline (task = "text-classification" , accelerator = "ort" )
5053 self .assertIsInstance (pipe , Pipeline )
5154 text = self ._create_dummy_text ()
5255 result = pipe (text )
@@ -58,7 +61,7 @@ def test_text_classification_pipeline(self):
5861
5962 def test_token_classification_pipeline (self ):
6063 """Test token classification ORT pipeline"""
61- pipe = pipeline (task = "token-classification" , accelerator = "ort" )
64+ pipe = optimum_pipeline (task = "token-classification" , accelerator = "ort" )
6265 self .assertIsInstance (pipe , Pipeline )
6366 text = self ._create_dummy_text ()
6467 result = pipe (text )
@@ -71,7 +74,7 @@ def test_token_classification_pipeline(self):
7174
7275 def test_question_answering_pipeline (self ):
7376 """Test question answering ORT pipeline"""
74- pipe = pipeline (task = "question-answering" , accelerator = "ort" )
77+ pipe = optimum_pipeline (task = "question-answering" , accelerator = "ort" )
7578 self .assertIsInstance (pipe , Pipeline )
7679 question = "What animal jumps?"
7780 context = "The quick brown fox jumps over the lazy dog."
@@ -85,7 +88,7 @@ def test_question_answering_pipeline(self):
8588
8689 def test_fill_mask_pipeline (self ):
8790 """Test fill mask ORT pipeline"""
88- pipe = pipeline (task = "fill-mask" , accelerator = "ort" )
91+ pipe = optimum_pipeline (task = "fill-mask" , accelerator = "ort" )
8992 self .assertIsInstance (pipe , Pipeline )
9093 text = "The weather is <mask> today."
9194 result = pipe (text )
@@ -97,7 +100,7 @@ def test_fill_mask_pipeline(self):
97100
98101 def test_feature_extraction_pipeline (self ):
99102 """Test feature extraction ORT pipeline"""
100- pipe = pipeline (task = "feature-extraction" , accelerator = "ort" )
103+ pipe = optimum_pipeline (task = "feature-extraction" , accelerator = "ort" )
101104 self .assertIsInstance (pipe , Pipeline )
102105 text = self ._create_dummy_text ()
103106 result = pipe (text )
@@ -108,10 +111,10 @@ def test_feature_extraction_pipeline(self):
108111
109112 def test_text_generation_pipeline (self ):
110113 """Test text generation ORT pipeline"""
111- pipe = pipeline (task = "text-generation" , accelerator = "ort" )
114+ pipe = optimum_pipeline (task = "text-generation" , accelerator = "ort" )
112115 self .assertIsInstance (pipe , Pipeline )
113116 text = "The future of AI is"
114- result = pipe (text , max_new_tokens = 50 , do_sample = False )
117+ result = pipe (text , ** GENERATE_KWARGS )
115118
116119 self .assertIsInstance (result , list )
117120 self .assertGreater (len (result ), 0 )
@@ -120,40 +123,40 @@ def test_text_generation_pipeline(self):
120123
121124 def test_summarization_pipeline (self ):
122125 """Test summarization ORT pipeline"""
123- pipe = pipeline (task = "summarization" , accelerator = "ort" )
126+ pipe = optimum_pipeline (task = "summarization" , accelerator = "ort" )
124127 self .assertIsInstance (pipe , Pipeline )
125128 text = "The quick brown fox jumps over the lazy dog."
126- result = pipe (text , max_new_tokens = 50 , min_new_tokens = 10 , do_sample = False )
129+ result = pipe (text , ** GENERATE_KWARGS )
127130
128131 self .assertIsInstance (result , list )
129132 self .assertGreater (len (result ), 0 )
130133 self .assertIn ("summary_text" , result [0 ])
131134
132135 def test_translation_pipeline (self ):
133136 """Test translation ORT pipeline"""
134- pipe = pipeline (task = "translation_en_to_de" , accelerator = "ort" )
137+ pipe = optimum_pipeline (task = "translation_en_to_de" , accelerator = "ort" )
135138 self .assertIsInstance (pipe , Pipeline )
136139 text = "Hello, how are you?"
137- result = pipe (text , max_new_tokens = 50 )
140+ result = pipe (text , ** GENERATE_KWARGS )
138141
139142 self .assertIsInstance (result , list )
140143 self .assertGreater (len (result ), 0 )
141144 self .assertIn ("translation_text" , result [0 ])
142145
143146 def test_text2text_generation_pipeline (self ):
144147 """Test text2text generation ORT pipeline"""
145- pipe = pipeline (task = "text2text-generation" , accelerator = "ort" )
148+ pipe = optimum_pipeline (task = "text2text-generation" , accelerator = "ort" )
146149 self .assertIsInstance (pipe , Pipeline )
147150 text = "translate English to German: Hello, how are you?"
148- result = pipe (text , max_new_tokens = 50 )
151+ result = pipe (text , ** GENERATE_KWARGS )
149152
150153 self .assertIsInstance (result , list )
151154 self .assertGreater (len (result ), 0 )
152155 self .assertIn ("generated_text" , result [0 ])
153156
154157 def test_zero_shot_classification_pipeline (self ):
155158 """Test zero shot classification ORT pipeline"""
156- pipe = pipeline (task = "zero-shot-classification" , accelerator = "ort" )
159+ pipe = optimum_pipeline (task = "zero-shot-classification" , accelerator = "ort" )
157160 self .assertIsInstance (pipe , Pipeline )
158161 text = "This is a great movie with excellent acting."
159162 candidate_labels = ["positive" , "negative" , "neutral" ]
@@ -166,7 +169,7 @@ def test_zero_shot_classification_pipeline(self):
166169
167170 def test_image_classification_pipeline (self ):
168171 """Test image classification ORT pipeline"""
169- pipe = pipeline (task = "image-classification" , accelerator = "ort" )
172+ pipe = optimum_pipeline (task = "image-classification" , accelerator = "ort" )
170173 self .assertIsInstance (pipe , Pipeline )
171174 image = self ._create_dummy_image ()
172175 result = pipe (image )
@@ -178,7 +181,7 @@ def test_image_classification_pipeline(self):
178181
179182 def test_image_segmentation_pipeline (self ):
180183 """Test image segmentation ORT pipeline"""
181- pipe = pipeline (task = "image-segmentation" , accelerator = "ort" )
184+ pipe = optimum_pipeline (task = "image-segmentation" , accelerator = "ort" )
182185 self .assertIsInstance (pipe , Pipeline )
183186 image = self ._create_dummy_image ()
184187 result = pipe (image )
@@ -191,18 +194,18 @@ def test_image_segmentation_pipeline(self):
191194
192195 def test_image_to_text_pipeline (self ):
193196 """Test image to text ORT pipeline"""
194- pipe = pipeline (task = "image-to-text" , accelerator = "ort" )
197+ pipe = optimum_pipeline (task = "image-to-text" , accelerator = "ort" )
195198 self .assertIsInstance (pipe , Pipeline )
196199 image = self ._create_dummy_image ()
197- result = pipe (image )
200+ result = pipe (image , generate_kwargs = GENERATE_KWARGS )
198201
199202 self .assertIsInstance (result , list )
200203 self .assertGreater (len (result ), 0 )
201204 self .assertIn ("generated_text" , result [0 ])
202205
203206 def test_image_to_image_pipeline (self ):
204207 """Test image to image ORT pipeline"""
205- pipe = pipeline (task = "image-to-image" , accelerator = "ort" )
208+ pipe = optimum_pipeline (task = "image-to-image" , accelerator = "ort" )
206209 self .assertIsInstance (pipe , Pipeline )
207210 image = self ._create_dummy_image ()
208211 result = pipe (image )
@@ -212,16 +215,16 @@ def test_image_to_image_pipeline(self):
212215 # TODO: Enable when fixed in optimum-onnx
213216 # def test_automatic_speech_recognition_pipeline(self):
214217 # """Test automatic speech recognition ORT pipeline"""
215- # pipe = pipeline (task="automatic-speech-recognition", accelerator="ort")
218+ # pipe = optimum_pipeline (task="automatic-speech-recognition", accelerator="ort")
216219 # audio = self._create_dummy_audio()
217- # result = pipe(audio)
220+ # result = pipe(audio, generate_kwargs=GENERATE_KWARGS )
218221
219222 # self.assertIsInstance(result, dict)
220223 # self.assertIn("text", result)
221224
222225 def test_audio_classification_pipeline (self ):
223226 """Test audio classification ORT pipeline"""
224- pipe = pipeline (task = "audio-classification" , accelerator = "ort" )
227+ pipe = optimum_pipeline (task = "audio-classification" , accelerator = "ort" )
225228 self .assertIsInstance (pipe , Pipeline )
226229 audio = self ._create_dummy_audio ()
227230 result = pipe (audio )
@@ -237,7 +240,7 @@ def test_pipeline_with_ort_model(self):
237240
238241 tokenizer = AutoTokenizer .from_pretrained ("distilbert-base-cased" )
239242 model = ORTModelForFeatureExtraction .from_pretrained ("distilbert-base-cased" , export = True )
240- pipe = pipeline (task = "feature-extraction" , model = model , tokenizer = tokenizer , accelerator = "ort" )
243+ pipe = optimum_pipeline (task = "feature-extraction" , model = model , tokenizer = tokenizer , accelerator = "ort" )
241244 self .assertIsInstance (pipe , Pipeline )
242245 text = self ._create_dummy_text ()
243246 result = pipe (text )
@@ -246,9 +249,9 @@ def test_pipeline_with_ort_model(self):
246249 self .assertIsInstance (result [0 ], list )
247250 self .assertIsInstance (result [0 ][0 ], list )
248251
249- def test_pipeline_with_custom_model_id (self ):
252+ def test_pipeline_with_model_id (self ):
250253 """Test ORT pipeline with a custom model id"""
251- pipe = pipeline (task = "feature-extraction" , model = "distilbert-base-cased" , accelerator = "ort" )
254+ pipe = optimum_pipeline (task = "feature-extraction" , model = "distilbert-base-cased" , accelerator = "ort" )
252255 self .assertIsInstance (pipe , Pipeline )
253256 text = self ._create_dummy_text ()
254257 result = pipe (text )
@@ -259,15 +262,18 @@ def test_pipeline_with_custom_model_id(self):
259262 def test_pipeline_with_invalid_task (self ):
260263 """Test ORT pipeline with an unsupported task"""
261264 with self .assertRaises (KeyError ) as context :
262- _ = pipeline (task = "invalid-task" , accelerator = "ort" )
265+ _ = optimum_pipeline (task = "invalid-task" , accelerator = "ort" )
263266 self .assertIn ("Unknown task invalid-task" , str (context .exception ))
264267
265268 def test_pipeline_with_invalid_accelerator (self ):
266269 """Test ORT pipeline with an unsupported accelerator"""
267270 with self .assertRaises (ValueError ) as context :
268- _ = pipeline (task = "text-classification " , accelerator = "invalid-accelerator" )
271+ _ = optimum_pipeline (task = "feature-extraction " , accelerator = "invalid-accelerator" )
269272 self .assertIn ("Accelerator invalid-accelerator not recognized" , str (context .exception ))
270273
274+ def tearDown (self ):
275+ remove_directory (HF_HUB_CACHE )
276+
271277
272278if __name__ == "__main__" :
273279 unittest .main ()
0 commit comments