
Commit 900b406

PR Feedback - Replace MMS with TS
1 parent 8225190 commit 900b406

11 files changed with 680 additions and 45 deletions

setup.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ def read(fname):
 
     packages=find_packages(where='src', exclude=('test',)),
     package_dir={'': 'src'},
+    package_data={'': ["etc/*"]},
     py_modules=[splitext(basename(path))[0] for path in glob('src/*.py')],
 
     long_description=read('README.rst'),
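
Note on the change above: the new package_data entry is what ships the bundled etc/* configuration files inside the installed package. A minimal sketch of why that matters, based on torchserve.py later in this commit, which resolves those files at runtime via pkg_resources (the lookup only succeeds if etc/* is actually packaged):

    import pkg_resources

    # Mirrors the lookup in torchserve.py below; without the package_data
    # entry in setup.py, resource_filename would point at a missing file.
    config_path = pkg_resources.resource_filename(
        "sagemaker_pytorch_serving_container", "/etc/default-ts.properties"
    )
    print(config_path)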
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/configuration.md
+enable_envvars_config=true
+decode_input_request=false
+load_models=ALL
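
Aside: enable_envvars_config=true lets the model server also pick up configuration from environment variables, so the static defaults in this file can be overridden at container start without editing it. A minimal sketch, assuming TorchServe's documented TS_<PROPERTY_NAME> convention (the exact variable name here is illustrative):

    import os
    import subprocess

    # Illustrative override: with enable_envvars_config=true, this takes
    # precedence over default_workers_per_model in the .properties file.
    os.environ["TS_DEFAULT_WORKERS_PER_MODEL"] = "2"

    subprocess.Popen(["torchserve", "--start", "--ts-config", "/etc/sagemaker-ts.properties"])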
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+log4j.rootLogger = INFO, console
+
+log4j.appender.console = org.apache.log4j.ConsoleAppender
+log4j.appender.console.Target = System.out
+log4j.appender.console.layout = org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.access_log = org.apache.log4j.RollingFileAppender
+log4j.appender.access_log.File = ${LOG_LOCATION}/access_log.log
+log4j.appender.access_log.MaxFileSize = 10MB
+log4j.appender.access_log.MaxBackupIndex = 5
+log4j.appender.access_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.access_log.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.ts_log = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_log.File = ${LOG_LOCATION}/ts_log.log
+log4j.appender.ts_log.MaxFileSize = 10MB
+log4j.appender.ts_log.MaxBackupIndex = 5
+log4j.appender.ts_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.ts_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_metrics.File = ${METRICS_LOCATION}/ts_metrics.log
+log4j.appender.ts_metrics.MaxFileSize = 10MB
+log4j.appender.ts_metrics.MaxBackupIndex = 5
+log4j.appender.ts_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.model_log = org.apache.log4j.RollingFileAppender
+log4j.appender.model_log.File = ${LOG_LOCATION}/model_log.log
+log4j.appender.model_log.MaxFileSize = 10MB
+log4j.appender.model_log.MaxBackupIndex = 5
+log4j.appender.model_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %c - %m%n
+
+log4j.appender.model_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.model_metrics.File = ${METRICS_LOCATION}/model_metrics.log
+log4j.appender.model_metrics.MaxFileSize = 10MB
+log4j.appender.model_metrics.MaxBackupIndex = 5
+log4j.appender.model_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.logger.com.amazonaws.ml.ts = INFO, ts_log
+log4j.logger.ACCESS_LOG = INFO, access_log
+log4j.logger.TS_METRICS = INFO, ts_metrics
+log4j.logger.MODEL_METRICS = INFO, model_metrics
+log4j.logger.MODEL_LOG = INFO, model_log
+
+log4j.logger.org.apache = OFF
+log4j.logger.io.netty = ERROR

src/sagemaker_pytorch_serving_container/handler_service.py

Lines changed: 1 addition & 2 deletions
@@ -14,8 +14,7 @@
 
 from sagemaker_inference.default_handler_service import DefaultHandlerService
 from sagemaker_inference.transformer import Transformer
-from sagemaker_pytorch_serving_container.default_inference_handler import \
-    DefaultPytorchInferenceHandler
+from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
 
 import os
 import sys

src/sagemaker_pytorch_serving_container/serving.py

Lines changed: 4 additions & 5 deletions
@@ -15,8 +15,7 @@
 from subprocess import CalledProcessError
 
 from retrying import retry
-from sagemaker_inference import torchserve
-
+from sagemaker_pytorch_serving_container import torchserve
 from sagemaker_pytorch_serving_container import handler_service
 
 HANDLER_SERVICE = handler_service.__file__
@@ -28,12 +27,12 @@ def _retry_if_error(exception):
 
 @retry(stop_max_delay=1000 * 30,
        retry_on_exception=_retry_if_error)
-def _start_model_server():
+def _start_torchserve():
     # there's a race condition that causes the model server command to
     # sometimes fail with 'bad address'. more investigation needed
     # retry starting mms until it's ready
-    torchserve.start_model_server(handler_service=HANDLER_SERVICE)
+    torchserve.start_torchserve(handler_service=HANDLER_SERVICE)
 
 
 def main():
-    _start_model_server()
+    _start_torchserve()
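
For context, the @retry decorator above re-invokes _start_torchserve for up to 30 seconds whenever _retry_if_error returns True for the raised exception; _retry_if_error's body falls outside this hunk. A minimal sketch of those semantics with a hypothetical predicate (the CalledProcessError import at the top of serving.py suggests it keys on failed server commands):

    from subprocess import CalledProcessError

    from retrying import retry

    def _retry_if_error(exception):
        # hypothetical stand-in for the real predicate
        return isinstance(exception, CalledProcessError)

    @retry(stop_max_delay=1000 * 30,  # keep retrying for up to 30,000 ms
           retry_on_exception=_retry_if_error)
    def _flaky_start():
        # any CalledProcessError raised here triggers another attempt
        raise CalledProcessError(returncode=1, cmd=["torchserve", "--start"])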
src/sagemaker_pytorch_serving_container/torchserve.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module contains functionality to configure and start the
+multi-model server."""
+from __future__ import absolute_import
+
+import os
+import signal
+import subprocess
+import sys
+
+import pkg_resources
+import psutil
+from retrying import retry
+
+import sagemaker_pytorch_serving_container
+from sagemaker_inference import default_handler_service, environment, logging, utils
+from sagemaker_inference.environment import code_dir
+
+logger = logging.get_logger()
+
+TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties")
+DEFAULT_HANDLER_SERVICE = default_handler_service.__name__
+DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename(
+    sagemaker_pytorch_serving_container.__name__, "/etc/default-ts.properties"
+)
+MME_TS_CONFIG_FILE = pkg_resources.resource_filename(
+    sagemaker_pytorch_serving_container.__name__, "/etc/mme-ts.properties"
+)
+DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename(
+    sagemaker_pytorch_serving_container.__name__, "/etc/log4j.properties"
+)
+DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/ts/models")
+DEFAULT_TS_MODEL_NAME = "model"
+DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth"
+DEFAULT_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service"
+
+ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
+MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY
+
+PYTHON_PATH_ENV = "PYTHONPATH"
+REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
+TS_NAMESPACE = "org.pytorch.serve.ModelServer"
+
+
+def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
+    """Configure and start the model server.
+
+    Args:
+        handler_service (str): python path pointing to a module that defines
+            a class with the following:
+
+            - A ``handle`` method, which is invoked for all incoming inference
+              requests to the model server.
+            - A ``initialize`` method, which is invoked at model server start up
+              for loading the model.
+
+            Defaults to ``sagemaker_pytorch_serving_container.default_handler_service``.
+
+    """
+
+    if ENABLE_MULTI_MODEL:
+        if not os.getenv("SAGEMAKER_HANDLER"):
+            os.environ["SAGEMAKER_HANDLER"] = handler_service
+        _set_python_path()
+    else:
+        _adapt_to_ts_format(handler_service)
+
+    _create_torchserve_config_file()
+
+    if os.path.exists(REQUIREMENTS_PATH):
+        _install_requirements()
+
+    ts_torchserve_cmd = [
+        "torchserve",
+        "--start",
+        "--model-store",
+        MODEL_STORE,
+        "--ts-config",
+        TS_CONFIG_FILE,
+        "--log-config",
+        DEFAULT_TS_LOG_FILE,
+        "--models",
+        "model.mar"
+    ]
+
+    print(ts_torchserve_cmd)
+
+    logger.info(ts_torchserve_cmd)
+    subprocess.Popen(ts_torchserve_cmd)
+
+    ts_process = _retrieve_ts_server_process()
+
+    _add_sigterm_handler(ts_process)
+
+    ts_process.wait()
+
+
+def _adapt_to_ts_format(handler_service):
+    if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
+        os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)
+
+    model_archiver_cmd = [
+        "torch-model-archiver",
+        "--model-name",
+        DEFAULT_TS_MODEL_NAME,
+        "--handler",
+        handler_service,
+        "--serialized-file",
+        os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE),
+        "--export-path",
+        DEFAULT_TS_MODEL_DIRECTORY,
+        "--extra-files",
+        os.path.join(environment.model_dir, environment.Environment().module_name + ".py"),
+        "--version",
+        "1",
+    ]
+
+    logger.info(model_archiver_cmd)
+    subprocess.check_call(model_archiver_cmd)
+
+    _set_python_path()
+
+
+def _set_python_path():
+    # Torchserve handles code execution by appending the export path, provided
+    # to the model archiver, to the PYTHONPATH env var.
+    # The code_dir has to be added to the PYTHONPATH otherwise the
+    # user provided module can not be imported properly.
+    code_dir_path = "{}:".format(environment.code_dir)
+
+    if PYTHON_PATH_ENV in os.environ:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
+    else:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path
+
+
+def _create_torchserve_config_file():
+    configuration_properties = _generate_ts_config_properties()
+
+    utils.write_file(TS_CONFIG_FILE, configuration_properties)
+
+
+def _generate_ts_config_properties():
+    env = environment.Environment()
+
+    user_defined_configuration = {
+        "default_response_timeout": env.torchserve_timeout,
+        "default_workers_per_model": env.torchserve_workers,
+        "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
+        "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
+    }
+
+    custom_configuration = str()
+
+    for key in user_defined_configuration:
+        value = user_defined_configuration.get(key)
+        if value:
+            custom_configuration += "{}={}\n".format(key, value)
+
+    if ENABLE_MULTI_MODEL:
+        default_configuration = utils.read_file(MME_TS_CONFIG_FILE)
+    else:
+        default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE)
+
+    return default_configuration + custom_configuration
+
+
+def _add_sigterm_handler(ts_process):
+    def _terminate(signo, frame):  # pylint: disable=unused-argument
+        try:
+            os.kill(ts_process.pid, signal.SIGTERM)
+        except OSError:
+            pass
+
+    signal.signal(signal.SIGTERM, _terminate)
+
+
+def _install_requirements():
+    logger.info("installing packages from requirements.txt...")
+    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
+
+    try:
+        subprocess.check_call(pip_install_cmd)
+    except subprocess.CalledProcessError:
+        logger.error("failed to install required packages, exiting")
+        raise ValueError("failed to install required packages")
+
+
+# retry for 10 seconds
+@retry(stop_max_delay=10 * 1000)
+def _retrieve_ts_server_process():
+    ts_server_processes = list()
+
+    for process in psutil.process_iter():
+        if TS_NAMESPACE in process.cmdline():
+            ts_server_processes.append(process)
+
+    if not ts_server_processes:
+        raise Exception("ts model server was unsuccessfully started")
+
+    if len(ts_server_processes) > 1:
+        raise Exception("multiple ts model servers are not supported")
+
+    return ts_server_processes[0]
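
The handler_service contract described in the start_torchserve docstring amounts to a module exposing a class with initialize and handle methods. A minimal sketch of that shape (illustrative only; the real implementation is src/sagemaker_pytorch_serving_container/handler_service.py, which builds on sagemaker_inference's DefaultHandlerService and Transformer):

    class HandlerService(object):
        def __init__(self):
            self._initialized = False

        def initialize(self, context):
            # invoked once at model server start-up; load the model here
            self._initialized = True

        def handle(self, data, context):
            # invoked for every inference request routed to this model
            if not self._initialized:
                self.initialize(context)
            return ["placeholder prediction"]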

test/unit/test_default_inference_handler.py

Lines changed: 3 additions & 33 deletions
@@ -24,7 +24,7 @@
 from six import StringIO, BytesIO
 from torch.autograd import Variable
 
-from sagemaker_pytorch_serving_container import default_inference_handler
+from sagemaker_pytorch_serving_container import default_pytorch_inference_handler
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -49,12 +49,12 @@ def fixture_tensor():
 
 @pytest.fixture()
 def inference_handler():
-    return default_inference_handler.DefaultPytorchInferenceHandler()
+    return default_pytorch_inference_handler.DefaultPytorchInferenceHandler()
 
 
 @pytest.fixture()
 def eia_inference_handler():
-    return default_inference_handler.DefaultPytorchInferenceHandler()
+    return default_pytorch_inference_handler.DefaultPytorchInferenceHandler()
 
 
 def test_default_model_fn(inference_handler):
@@ -178,33 +178,3 @@ def test_default_output_fn_gpu(inference_handler):
     output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)
 
     assert "1,2,3\n4,5,6\n".encode("utf-8") == output
-
-
-def test_eia_default_model_fn(eia_inference_handler):
-    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
-        mock_os.getenv.return_value = "true"
-        mock_os.path.join.return_value = "model_dir"
-        mock_os.path.exists.return_value = True
-        with mock.patch("torch.jit.load") as mock_torch:
-            mock_torch.return_value = DummyModel()
-            model = eia_inference_handler.default_model_fn("model_dir")
-            assert model is not None
-
-
-def test_eia_default_model_fn_error(eia_inference_handler):
-    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
-        mock_os.getenv.return_value = "true"
-        mock_os.path.join.return_value = "model_dir"
-        mock_os.path.exists.return_value = False
-        with pytest.raises(FileNotFoundError):
-            eia_inference_handler.default_model_fn("model_dir")
-
-
-def test_eia_default_predict_fn(eia_inference_handler, tensor):
-    model = DummyModel()
-    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
-        mock_os.getenv.return_value = "true"
-        with mock.patch("torch.jit.optimized_execution") as mock_torch:
-            mock_torch.__enter__.return_value = "dummy"
-            eia_inference_handler.default_predict_fn(tensor, model)
-            mock_torch.assert_called_once()
