14
14
15
15
import csv
16
16
import json
17
+ import os
17
18
18
19
import mock
19
20
import numpy as np
30
31
31
32
32
33
class DummyModel (nn .Module ):
33
-
34
- def __init__ (self , ):
34
+ def __init__ (self ):
35
35
super (DummyModel , self ).__init__ ()
36
36
37
37
def forward (self , x ):
@@ -58,7 +58,9 @@ def eia_inference_handler():
58
58
59
59
60
60
def test_default_model_fn (inference_handler ):
61
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
61
+ with mock .patch (
62
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
63
+ ) as mock_os :
62
64
mock_os .getenv .return_value = "true"
63
65
mock_os .path .join .return_value = "model_dir"
64
66
mock_os .path .exists .return_value = True
@@ -69,7 +71,9 @@ def test_default_model_fn(inference_handler):
69
71
70
72
71
73
def test_default_model_fn_unknown_name (inference_handler ):
72
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
74
+ with mock .patch (
75
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
76
+ ) as mock_os :
73
77
mock_os .getenv .return_value = "false"
74
78
mock_os .path .join .return_value = "model_dir"
75
79
mock_os .path .isfile .return_value = True
@@ -80,18 +84,25 @@ def test_default_model_fn_unknown_name(inference_handler):
80
84
assert model is not None
81
85
82
86
83
- @pytest .mark .parametrize ("listdir_return_value" , [["abcd.py" , "efgh.txt" , "ijkl.bin" ], ["abcd.pt" , "efgh.pth" ]])
87
+ @pytest .mark .parametrize (
88
+ "listdir_return_value" , [["abcd.py" , "efgh.txt" , "ijkl.bin" ], ["abcd.pt" , "efgh.pth" ]]
89
+ )
84
90
def test_default_model_fn_no_model_file (inference_handler , listdir_return_value ):
85
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
91
+ with mock .patch (
92
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
93
+ ) as mock_os :
86
94
mock_os .getenv .return_value = "false"
87
95
mock_os .path .join .return_value = "model_dir"
96
+ mock_os .path .exists .return_value = False
88
97
mock_os .path .isfile .return_value = True
89
98
mock_os .listdir .return_value = listdir_return_value
99
+ mock_os .path .splitext = os .path .splitext
90
100
with mock .patch ("torch.jit.load" ) as mock_torch_load :
91
101
mock_torch_load .return_value = DummyModel ()
92
- with pytest .raises (ValueError ):
102
+ with pytest .raises (
103
+ ValueError , match = r"Exactly one .pth or .pt file is required for PyTorch models: .*"
104
+ ):
93
105
model = inference_handler .default_model_fn ("model_dir" )
94
- assert model is not None
95
106
96
107
97
108
def test_default_input_fn_json (inference_handler , tensor ):
@@ -220,7 +231,9 @@ def test_default_output_fn_gpu(inference_handler):
220
231
221
232
222
233
def test_eia_default_model_fn (eia_inference_handler ):
223
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
234
+ with mock .patch (
235
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
236
+ ) as mock_os :
224
237
mock_os .getenv .return_value = "true"
225
238
mock_os .path .join .return_value = "model_dir"
226
239
mock_os .path .exists .return_value = True
@@ -231,7 +244,9 @@ def test_eia_default_model_fn(eia_inference_handler):
231
244
232
245
233
246
def test_eia_default_model_fn_error (eia_inference_handler ):
234
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
247
+ with mock .patch (
248
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
249
+ ) as mock_os :
235
250
mock_os .getenv .return_value = "true"
236
251
mock_os .path .join .return_value = "model_dir"
237
252
mock_os .path .exists .return_value = False
@@ -241,7 +256,9 @@ def test_eia_default_model_fn_error(eia_inference_handler):
241
256
242
257
def test_eia_default_predict_fn (eia_inference_handler , tensor ):
243
258
model = DummyModel ()
244
- with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
259
+ with mock .patch (
260
+ "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
261
+ ) as mock_os :
245
262
mock_os .getenv .return_value = "true"
246
263
with mock .patch ("torch.jit.optimized_execution" ) as mock_torch :
247
264
mock_torch .__enter__ .return_value = "dummy"
0 commit comments