@@ -163,32 +163,33 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
163
163
assert a .dtype == dtype
164
164
165
165
166
- @pytest .mark .parametrize ("torch_dtype" , [torch .float32 , torch .bfloat16 ])
167
- def test_pipeline_predict (torch_dtype : str ):
166
+ @pytest .mark .parametrize ("model_dtype" , [torch .float32 , torch .bfloat16 ])
167
+ @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .bfloat16 ])
168
+ def test_pipeline_predict (model_dtype : torch .dtype , input_dtype : torch .dtype ):
168
169
pipeline = ChronosPipeline .from_pretrained (
169
170
Path (__file__ ).parent / "dummy-chronos-model" ,
170
171
device_map = "cpu" ,
171
- torch_dtype = torch_dtype ,
172
+ torch_dtype = model_dtype ,
172
173
)
173
- context = 10 * torch .rand (size = (4 , 16 )) + 10
174
+ context = 10 * torch .rand (size = (4 , 16 ), dtype = input_dtype ) + 10
174
175
175
176
# input: tensor of shape (batch_size, context_length)
176
177
177
178
samples = pipeline .predict (context , num_samples = 12 , prediction_length = 3 )
178
- validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = torch . float32 )
179
+ validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = input_dtype )
179
180
180
181
with pytest .raises (ValueError ):
181
182
samples = pipeline .predict (context , num_samples = 7 , prediction_length = 65 )
182
183
183
184
samples = pipeline .predict (
184
185
context , num_samples = 7 , prediction_length = 65 , limit_prediction_length = False
185
186
)
186
- validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = torch . float32 )
187
+ validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = input_dtype )
187
188
188
189
# input: batch_size-long list of tensors of shape (context_length,)
189
190
190
191
samples = pipeline .predict (list (context ), num_samples = 12 , prediction_length = 3 )
191
- validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = torch . float32 )
192
+ validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = input_dtype )
192
193
193
194
with pytest .raises (ValueError ):
194
195
samples = pipeline .predict (list (context ), num_samples = 7 , prediction_length = 65 )
@@ -199,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str):
199
200
prediction_length = 65 ,
200
201
limit_prediction_length = False ,
201
202
)
202
- validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = torch . float32 )
203
+ validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = input_dtype )
203
204
204
205
# input: tensor of shape (context_length,)
205
206
206
207
samples = pipeline .predict (context [0 , ...], num_samples = 12 , prediction_length = 3 )
207
- validate_tensor (samples , shape = (1 , 12 , 3 ), dtype = torch . float32 )
208
+ validate_tensor (samples , shape = (1 , 12 , 3 ), dtype = input_dtype )
208
209
209
210
with pytest .raises (ValueError ):
210
211
samples = pipeline .predict (context [0 , ...], num_samples = 7 , prediction_length = 65 )
@@ -215,40 +216,41 @@ def test_pipeline_predict(torch_dtype: str):
215
216
prediction_length = 65 ,
216
217
limit_prediction_length = False ,
217
218
)
218
- validate_tensor (samples , shape = (1 , 7 , 65 ), dtype = torch . float32 )
219
+ validate_tensor (samples , shape = (1 , 7 , 65 ), dtype = input_dtype )
219
220
220
221
221
- @pytest .mark .parametrize ("torch_dtype" , [torch .float32 , torch .bfloat16 ])
222
- def test_pipeline_embed (torch_dtype : str ):
222
+ @pytest .mark .parametrize ("model_dtype" , [torch .float32 , torch .bfloat16 ])
223
+ @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .bfloat16 ])
224
+ def test_pipeline_embed (model_dtype : torch .dtype , input_dtype : torch .dtype ):
223
225
pipeline = ChronosPipeline .from_pretrained (
224
226
Path (__file__ ).parent / "dummy-chronos-model" ,
225
227
device_map = "cpu" ,
226
- torch_dtype = torch_dtype ,
228
+ torch_dtype = model_dtype ,
227
229
)
228
230
d_model = pipeline .model .model .config .d_model
229
- context = 10 * torch .rand (size = (4 , 16 )) + 10
231
+ context = 10 * torch .rand (size = (4 , 16 ), dtype = input_dtype ) + 10
230
232
expected_embed_length = 16 + (1 if pipeline .model .config .use_eos_token else 0 )
231
233
232
234
# input: tensor of shape (batch_size, context_length)
233
235
234
236
embedding , scale = pipeline .embed (context )
235
237
validate_tensor (
236
- embedding , shape = (4 , expected_embed_length , d_model ), dtype = torch_dtype
238
+ embedding , shape = (4 , expected_embed_length , d_model ), dtype = model_dtype
237
239
)
238
240
validate_tensor (scale , shape = (4 ,), dtype = torch .float32 )
239
241
240
242
# input: batch_size-long list of tensors of shape (context_length,)
241
243
242
244
embedding , scale = pipeline .embed (list (context ))
243
245
validate_tensor (
244
- embedding , shape = (4 , expected_embed_length , d_model ), dtype = torch_dtype
246
+ embedding , shape = (4 , expected_embed_length , d_model ), dtype = model_dtype
245
247
)
246
248
validate_tensor (scale , shape = (4 ,), dtype = torch .float32 )
247
249
248
250
# input: tensor of shape (context_length,)
249
251
embedding , scale = pipeline .embed (context [0 , ...])
250
252
validate_tensor (
251
- embedding , shape = (1 , expected_embed_length , d_model ), dtype = torch_dtype
253
+ embedding , shape = (1 , expected_embed_length , d_model ), dtype = model_dtype
252
254
)
253
255
validate_tensor (scale , shape = (1 ,), dtype = torch .float32 )
254
256
0 commit comments