@@ -68,6 +68,32 @@ def test_default_model_fn(inference_handler):
68
68
assert model is not None
69
69
70
70
71
def test_default_model_fn_unknown_name(inference_handler):
    """A directory with exactly one .pt file among other extensions loads successfully.

    The os module of the handler is fully mocked so no real filesystem is touched;
    torch.jit.load is mocked so no real TorchScript artifact is required.
    """
    handler_os = "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
    with mock.patch(handler_os) as patched_os, mock.patch("torch.jit.load") as patched_load:
        # Pretend TorchServe defaults are off and the model dir exists.
        patched_os.getenv.return_value = "false"
        patched_os.path.join.return_value = "model_dir"
        patched_os.path.isfile.return_value = True
        # One TorchScript candidate (.pt) plus two unrelated files.
        patched_os.listdir.return_value = ["abcd.pt", "efgh.txt", "ijkl.bin"]
        patched_load.return_value = DummyModel()

        loaded = inference_handler.default_model_fn("model_dir")
        assert loaded is not None
81
+
82
+
83
@pytest.mark.parametrize("listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]])
def test_default_model_fn_no_model_file(inference_handler, listdir_return_value):
    """default_model_fn raises ValueError when the dir has no usable model file or
    is ambiguous (zero .pt files, or more than one candidate).

    Parametrized with both failure shapes; os and torch.jit.load are mocked so the
    test exercises only the file-selection logic.
    """
    with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
        mock_os.getenv.return_value = "false"
        mock_os.path.join.return_value = "model_dir"
        mock_os.path.isfile.return_value = True
        mock_os.listdir.return_value = listdir_return_value
        with mock.patch("torch.jit.load") as mock_torch_load:
            mock_torch_load.return_value = DummyModel()
            with pytest.raises(ValueError):
                # Fix: the original asserted on the return value inside this block,
                # but that line was unreachable — the expected ValueError aborts the
                # block before any statement after the call can run. The bare call
                # is all pytest.raises needs.
                inference_handler.default_model_fn("model_dir")
95
+
96
+
71
97
def test_default_input_fn_json (inference_handler , tensor ):
72
98
json_data = json .dumps (tensor .cpu ().numpy ().tolist ())
73
99
deserialized_np_array = inference_handler .default_input_fn (json_data , content_types .JSON )
0 commit comments