Skip to content

Add functionality to configure TorchServe logging levels using the TS_LOG_LEVEL environment variable. #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions src/sagemaker_pytorch_serving_container/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os,sys
import subprocess
from subprocess import CalledProcessError

from retrying import retry
Expand All @@ -20,11 +22,49 @@

HANDLER_SERVICE = handler_service.__file__

## added logging function to configure log4j2 loglevel.
def configure_logging():
log_levels = {
'0': 'off',
'10': 'fatal',
'20': 'error',
'30': 'warn',
'40': 'info',
'50': 'debug',
'60': 'trace'
}

# Get the directory of the current script
current_script_path = os.path.abspath(__file__)

# Construct the path to log4j2.xml relative to the script location
log4j2_path = os.path.join(os.path.dirname(current_script_path), 'etc', 'log4j2.xml')

print(f"Current script path: {current_script_path}")
print(f"log4j2.xml path: {log4j2_path}")

if not os.path.exists(log4j2_path):
print(f"Error: {log4j2_path} does not exist", file=sys.stderr)
return

ts_log_level = os.environ.get('TS_LOG_LEVEL')

if ts_log_level is not None:
if ts_log_level in log_levels:
try:
log_level = log_levels[ts_log_level]
subprocess.run(['sed', '-i', f's/info/{log_level}/g', log4j2_path], check=True)
print(f"Logging level set to {log_level}")
except subprocess.CalledProcessError as e:
print(f"Error configuring the logging: {e}", file=sys.stderr)
else:
print(f"Invalid TS_LOG_LEVEL value: {ts_log_level}. No changes made to logging configuration.", file=sys.stderr)
else:
print("TS_LOG_LEVEL not set. Using default logging configuration.")

def _retry_if_error(exception):
return isinstance(exception, CalledProcessError)


@retry(stop_max_delay=1000 * 30,
retry_on_exception=_retry_if_error)
def _start_torchserve():
Expand All @@ -33,6 +73,9 @@ def _start_torchserve():
# retry starting mms until it's ready
torchserve.start_torchserve(handler_service=HANDLER_SERVICE)


def main():
configure_logging()
_start_torchserve()

if __name__ == '__main__':
main()
68 changes: 68 additions & 0 deletions test/unit/test_log_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import unittest
from unittest.mock import patch, MagicMock
import os, subprocess
import sys
import io

# Add the src directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src')))

from sagemaker_pytorch_serving_container.serving import configure_logging

class TestLogConfig(unittest.TestCase):
@patch('os.path.exists')
@patch('os.environ.get')
@patch('subprocess.run')
def test_valid_log_level(self, mock_run, mock_env_get, mock_exists):
mock_exists.return_value = True
mock_env_get.return_value = '20'
mock_run.return_value = MagicMock(returncode=0)

with patch('sys.stdout', new=io.StringIO()) as fake_out:
configure_logging()
self.assertIn("Logging level set to error", fake_out.getvalue())

mock_run.assert_called_once()

@patch('os.path.exists')
@patch('os.environ.get')
def test_invalid_log_level(self, mock_env_get, mock_exists):
mock_exists.return_value = True
mock_env_get.return_value = '70'

with patch('sys.stderr', new=io.StringIO()) as fake_err:
configure_logging()
self.assertIn("Invalid TS_LOG_LEVEL value: 70", fake_err.getvalue())

@patch('os.path.exists')
@patch('os.environ.get')
def test_no_log_level_set(self, mock_env_get, mock_exists):
mock_exists.return_value = True
mock_env_get.return_value = None

with patch('sys.stdout', new=io.StringIO()) as fake_out:
configure_logging()
self.assertIn("TS_LOG_LEVEL not set", fake_out.getvalue())

@patch('os.path.exists')
@patch('os.environ.get')
@patch('subprocess.run')
def test_subprocess_error(self, mock_run, mock_env_get, mock_exists):
mock_exists.return_value = True
mock_env_get.return_value = '20'
mock_run.side_effect = subprocess.CalledProcessError(1, 'sed')

with patch('sys.stderr', new=io.StringIO()) as fake_err:
configure_logging()
self.assertIn("Error configuring the logging", fake_err.getvalue())

@patch('os.path.exists')
def test_log4j2_file_not_found(self, mock_exists):
mock_exists.return_value = False

with patch('sys.stderr', new=io.StringIO()) as fake_err:
configure_logging()
self.assertIn("does not exist", fake_err.getvalue())

if __name__ == '__main__':
unittest.main()