31
31
32
32
33
33
class DummyModel (nn .Module ):
34
- def __init__ (self ):
34
+
35
+ def __init__ (self , ):
35
36
super (DummyModel , self ).__init__ ()
36
37
37
38
def forward (self , x ):
@@ -58,11 +59,9 @@ def eia_inference_handler():
58
59
59
60
60
61
def test_default_model_fn (inference_handler ):
61
- with mock .patch (
62
- "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
63
- ) as mock_os :
62
+ with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
64
63
mock_os .getenv .return_value = "true"
65
- mock_os .path .join . return_value = "model_dir"
64
+ mock_os .path .join = os . path . join
66
65
mock_os .path .exists .return_value = True
67
66
with mock .patch ("torch.jit.load" ) as mock_torch :
68
67
mock_torch .return_value = DummyModel ()
@@ -71,11 +70,9 @@ def test_default_model_fn(inference_handler):
71
70
72
71
73
72
def test_default_model_fn_unknown_name (inference_handler ):
74
- with mock .patch (
75
- "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
76
- ) as mock_os :
73
+ with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
77
74
mock_os .getenv .return_value = "false"
78
- mock_os .path .join . return_value = "model_dir"
75
+ mock_os .path .join = os . path . join
79
76
mock_os .path .isfile .return_value = True
80
77
mock_os .listdir .return_value = ["abcd.pt" , "efgh.txt" , "ijkl.bin" ]
81
78
with mock .patch ("torch.jit.load" ) as mock_torch_load :
@@ -99,9 +96,7 @@ def test_default_model_fn_no_model_file(inference_handler, listdir_return_value)
99
96
mock_os .path .splitext = os .path .splitext
100
97
with mock .patch ("torch.jit.load" ) as mock_torch_load :
101
98
mock_torch_load .return_value = DummyModel ()
102
- with pytest .raises (
103
- ValueError , match = r"Exactly one .pth or .pt file is required for PyTorch models: .*"
104
- ):
99
+ with pytest .raises (ValueError , match = r"Exactly one .pth or .pt file is required for PyTorch models: .*" ):
105
100
inference_handler .default_model_fn ("model_dir" )
106
101
107
102
@@ -231,9 +226,7 @@ def test_default_output_fn_gpu(inference_handler):
231
226
232
227
233
228
def test_eia_default_model_fn (eia_inference_handler ):
234
- with mock .patch (
235
- "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
236
- ) as mock_os :
229
+ with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
237
230
mock_os .getenv .return_value = "true"
238
231
mock_os .path .join .return_value = "model_dir"
239
232
mock_os .path .exists .return_value = True
@@ -244,9 +237,7 @@ def test_eia_default_model_fn(eia_inference_handler):
244
237
245
238
246
239
def test_eia_default_model_fn_error (eia_inference_handler ):
247
- with mock .patch (
248
- "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
249
- ) as mock_os :
240
+ with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
250
241
mock_os .getenv .return_value = "true"
251
242
mock_os .path .join .return_value = "model_dir"
252
243
mock_os .path .exists .return_value = False
@@ -256,9 +247,7 @@ def test_eia_default_model_fn_error(eia_inference_handler):
256
247
257
248
def test_eia_default_predict_fn (eia_inference_handler , tensor ):
258
249
model = DummyModel ()
259
- with mock .patch (
260
- "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
261
- ) as mock_os :
250
+ with mock .patch ("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os" ) as mock_os :
262
251
mock_os .getenv .return_value = "true"
263
252
with mock .patch ("torch.jit.optimized_execution" ) as mock_torch :
264
253
mock_torch .__enter__ .return_value = "dummy"
0 commit comments