Skip to content

Commit 94dd4d5

Browse files
committed
Add unit tests for unknown model names
1 parent 414c83a commit 94dd4d5

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/unit/test_default_inference_handler.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,32 @@ def test_default_model_fn(inference_handler):
6868
assert model is not None
6969

7070

71+
def test_default_model_fn_unknown_name(inference_handler):
72+
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
73+
mock_os.getenv.return_value = "false"
74+
mock_os.path.join.return_value = "model_dir"
75+
mock_os.path.isfile.return_value = True
76+
mock_os.listdir.return_value = ["abcd.pt", "efgh.txt", "ijkl.bin"]
77+
with mock.patch("torch.jit.load") as mock_torch_load:
78+
mock_torch_load.return_value = DummyModel()
79+
model = inference_handler.default_model_fn("model_dir")
80+
assert model is not None
81+
82+
83+
@pytest.mark.parametrize("listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]])
84+
def test_default_model_fn_no_model_file(inference_handler, listdir_return_value):
85+
with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
86+
mock_os.getenv.return_value = "false"
87+
mock_os.path.join.return_value = "model_dir"
88+
mock_os.path.isfile.return_value = True
89+
mock_os.listdir.return_value = listdir_return_value
90+
with mock.patch("torch.jit.load") as mock_torch_load:
91+
mock_torch_load.return_value = DummyModel()
92+
with pytest.raises(ValueError):
93+
model = inference_handler.default_model_fn("model_dir")
94+
assert model is not None
95+
96+
7197
def test_default_input_fn_json(inference_handler, tensor):
7298
json_data = json.dumps(tensor.cpu().numpy().tolist())
7399
deserialized_np_array = inference_handler.default_input_fn(json_data, content_types.JSON)

0 commit comments

Comments
 (0)