diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py index d90e0000..c9b01ee8 100644 --- a/src/sagemaker_pytorch_serving_container/torchserve.py +++ b/src/sagemaker_pytorch_serving_container/torchserve.py @@ -50,6 +50,7 @@ PYTHON_PATH_ENV = "PYTHONPATH" REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt") +LOG4J_OVERRIDE_PATH = os.path.join(code_dir, "log4j.xml") TS_NAMESPACE = "org.pytorch.serve.ModelServer" @@ -81,6 +82,11 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE): if os.path.exists(REQUIREMENTS_PATH): _install_requirements() + if os.path.exists(LOG4J_OVERRIDE_PATH): + log4j_path = LOG4J_OVERRIDE_PATH + else: + log4j_path = DEFAULT_TS_LOG_FILE + ts_torchserve_cmd = [ "torchserve", "--start", @@ -89,7 +95,7 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE): "--ts-config", TS_CONFIG_FILE, "--log-config", - DEFAULT_TS_LOG_FILE, + log4j_path, "--models", "model.mar" ] diff --git a/test/unit/test_model_server.py b/test/unit/test_model_server.py index aeaec28e..10598af8 100644 --- a/test/unit/test_model_server.py +++ b/test/unit/test_model_server.py @@ -21,7 +21,9 @@ from sagemaker_inference import environment from sagemaker_pytorch_serving_container import torchserve -from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE, REQUIREMENTS_PATH +from sagemaker_pytorch_serving_container.torchserve import ( + TS_NAMESPACE, REQUIREMENTS_PATH, LOG4J_OVERRIDE_PATH +) PYTHON_PATH = "python_path" DEFAULT_CONFIGURATION = "default_configuration" @@ -32,7 +34,7 @@ @patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") @patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") @patch("sagemaker_pytorch_serving_container.torchserve._install_requirements") -@patch("os.path.exists", return_value=True) +@patch("os.path.exists", side_effect=[True, False]) @patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") @patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format") def test_start_torchserve_default_service_handler( @@ -49,7 +51,8 @@ def test_start_torchserve_default_service_handler( adapt.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE) create_config.assert_called_once_with() - exists.assert_called_once_with(REQUIREMENTS_PATH) + exists.assert_any_call(REQUIREMENTS_PATH) + exists.assert_any_call(LOG4J_OVERRIDE_PATH) install_requirements.assert_called_once_with() ts_model_server_cmd = [ @@ -74,7 +77,7 @@ def test_start_torchserve_default_service_handler( @patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") @patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") @patch("sagemaker_pytorch_serving_container.torchserve._install_requirements") -@patch("os.path.exists", return_value=True) +@patch("os.path.exists", side_effect=[True, False]) @patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") @patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format") def test_start_torchserve_default_service_handler_multi_model( @@ -91,7 +94,8 @@ def test_start_torchserve_default_service_handler_multi_model( torchserve.start_torchserve() torchserve.ENABLE_MULTI_MODEL = False create_config.assert_called_once_with() - exists.assert_called_once_with(REQUIREMENTS_PATH) + exists.assert_any_call(REQUIREMENTS_PATH) + exists.assert_any_call(LOG4J_OVERRIDE_PATH) install_requirements.assert_called_once_with() ts_model_server_cmd = [