Skip to content

Commit 945c677

Browse files
committed
Fix unit tests
1 parent 94dd4d5 commit 945c677

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

test/unit/test_default_inference_handler.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import csv
1616
import json
17+
import os
1718

1819
import mock
1920
import numpy as np
@@ -30,8 +31,7 @@
3031

3132

3233
class DummyModel(nn.Module):
33-
34-
def __init__(self, ):
34+
def __init__(self):
3535
super(DummyModel, self).__init__()
3636

3737
def forward(self, x):
@@ -58,7 +58,9 @@ def eia_inference_handler():
5858

5959

6060
def test_default_model_fn(inference_handler):
61-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
61+
with mock.patch(
62+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
63+
) as mock_os:
6264
mock_os.getenv.return_value = "true"
6365
mock_os.path.join.return_value = "model_dir"
6466
mock_os.path.exists.return_value = True
@@ -69,7 +71,9 @@ def test_default_model_fn(inference_handler):
6971

7072

7173
def test_default_model_fn_unknown_name(inference_handler):
72-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
74+
with mock.patch(
75+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
76+
) as mock_os:
7377
mock_os.getenv.return_value = "false"
7478
mock_os.path.join.return_value = "model_dir"
7579
mock_os.path.isfile.return_value = True
@@ -80,18 +84,25 @@ def test_default_model_fn_unknown_name(inference_handler):
8084
assert model is not None
8185

8286

83-
@pytest.mark.parametrize("listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]])
87+
@pytest.mark.parametrize(
88+
"listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]]
89+
)
8490
def test_default_model_fn_no_model_file(inference_handler, listdir_return_value):
85-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
91+
with mock.patch(
92+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
93+
) as mock_os:
8694
mock_os.getenv.return_value = "false"
8795
mock_os.path.join.return_value = "model_dir"
96+
mock_os.path.exists.return_value = False
8897
mock_os.path.isfile.return_value = True
8998
mock_os.listdir.return_value = listdir_return_value
99+
mock_os.path.splitext = os.path.splitext
90100
with mock.patch("torch.jit.load") as mock_torch_load:
91101
mock_torch_load.return_value = DummyModel()
92-
with pytest.raises(ValueError):
102+
with pytest.raises(
103+
ValueError, match=r"Exactly one .pth or .pt file is required for PyTorch models: .*"
104+
):
93105
model = inference_handler.default_model_fn("model_dir")
94-
assert model is not None
95106

96107

97108
def test_default_input_fn_json(inference_handler, tensor):
@@ -220,7 +231,9 @@ def test_default_output_fn_gpu(inference_handler):
220231

221232

222233
def test_eia_default_model_fn(eia_inference_handler):
223-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
234+
with mock.patch(
235+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
236+
) as mock_os:
224237
mock_os.getenv.return_value = "true"
225238
mock_os.path.join.return_value = "model_dir"
226239
mock_os.path.exists.return_value = True
@@ -231,7 +244,9 @@ def test_eia_default_model_fn(eia_inference_handler):
231244

232245

233246
def test_eia_default_model_fn_error(eia_inference_handler):
234-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
247+
with mock.patch(
248+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
249+
) as mock_os:
235250
mock_os.getenv.return_value = "true"
236251
mock_os.path.join.return_value = "model_dir"
237252
mock_os.path.exists.return_value = False
@@ -241,7 +256,9 @@ def test_eia_default_model_fn_error(eia_inference_handler):
241256

242257
def test_eia_default_predict_fn(eia_inference_handler, tensor):
243258
model = DummyModel()
244-
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
259+
with mock.patch(
260+
"sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
261+
) as mock_os:
245262
mock_os.getenv.return_value = "true"
246263
with mock.patch("torch.jit.optimized_execution") as mock_torch:
247264
mock_torch.__enter__.return_value = "dummy"

0 commit comments

Comments
 (0)