Skip to content

Commit 3fa2dea

Browse files
authored
pass model directory as input to torchserve (#118)
* update torchserve * remove repackaging fn * update assert check * update torchserve in general dockerfile * test with dlc dockerfiles * uninstall auto confirmation * uninstall auto confirmation * run only gpu tests * run all integ tests * set default service handler in ts config file * test * test * revert passing handler service to ts config * Revert "revert passing handler service to ts config" This reverts commit d62f5ff. * add pytest logs * build/push dockerfile * pass handler fn * skip unit test * add logging to sm tests * test * test * test * fix flake8 * fix unit test * test gpu sm generic * skip sm integration tests with generic image * test generic image * enable all tests
1 parent 17613f1 commit 3fa2dea

File tree

6 files changed

+38
-116
lines changed

6 files changed

+38
-116
lines changed

buildspec.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ phases:
4242
- DLC_EIA_TAG="$EIA_FRAMEWORK_VERSION-dlc-eia-$BUILD_ID"
4343

4444
# run local CPU integration tests (build and push the image to ECR repo)
45-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --build-image --push-image --dockerfile-type pytorch --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --tag $GENERIC_TAG"
45+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local -vv -rA -s --build-image --push-image --dockerfile-type pytorch --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --tag $GENERIC_TAG"
4646
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec-toolkit.yml" "artifacts/*"
47-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --build-image --push-image --dockerfile-type dlc.cpu --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --tag $DLC_CPU_TAG"
47+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local -vv -rA -s --build-image --push-image --dockerfile-type dlc.cpu --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --tag $DLC_CPU_TAG"
4848
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec-toolkit.yml" "artifacts/*"
4949

5050
# launch remote GPU instance
@@ -65,10 +65,10 @@ phases:
6565
# run GPU local integration tests
6666
- printf "$SETUP_CMDS" > $SETUP_FILE
6767
# no reason to rebuild the image again since it was already built and pushed to ECR during CPU tests
68-
- generic_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $GENERIC_TAG"
68+
- generic_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local -vv -rA -s --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $GENERIC_TAG"
6969
- test_cmd="remote-test --github-repo $GITHUB_REPO --test-cmd \"$generic_cmd\" --setup-file $SETUP_FILE --pr-number \"$PR_NUM\""
7070
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec-toolkit.yml" "artifacts/*"
71-
- dlc_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $DLC_GPU_TAG"
71+
- dlc_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local -vv -rA -s --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $DLC_GPU_TAG"
7272
- test_cmd="remote-test --github-repo $GITHUB_REPO --test-cmd \"$dlc_cmd\" --setup-file $SETUP_FILE --pr-number \"$PR_NUM\" --skip-setup"
7373
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec-toolkit.yml" "artifacts/*"
7474

src/sagemaker_pytorch_serving_container/torchserve.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525

2626
import sagemaker_pytorch_serving_container
2727
from sagemaker_pytorch_serving_container import ts_environment
28-
from sagemaker_inference import default_handler_service, environment, utils
28+
from sagemaker_inference import environment, utils
2929
from sagemaker_inference.environment import code_dir
3030

3131
logger = logging.getLogger()
3232

3333
TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties")
34-
DEFAULT_HANDLER_SERVICE = default_handler_service.__name__
3534
DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename(
3635
sagemaker_pytorch_serving_container.__name__, "/etc/default-ts.properties"
3736
)
@@ -41,13 +40,11 @@
4140
DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename(
4241
sagemaker_pytorch_serving_container.__name__, "/etc/log4j2.xml"
4342
)
44-
DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker", "ts", "models")
4543
DEFAULT_TS_MODEL_NAME = "model"
46-
DEFAULT_TS_CODE_DIR = "code"
4744
DEFAULT_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service"
4845

4946
ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
50-
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY
47+
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else os.path.join(os.getcwd(), ".sagemaker", "ts", "models")
5148

5249
PYTHON_PATH_ENV = "PYTHONPATH"
5350
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
@@ -73,11 +70,13 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
7370
if ENABLE_MULTI_MODEL:
7471
if "SAGEMAKER_HANDLER" not in os.environ:
7572
os.environ["SAGEMAKER_HANDLER"] = handler_service
76-
_set_python_path()
7773
else:
78-
_adapt_to_ts_format(handler_service)
74+
if not os.path.exists(MODEL_STORE):
75+
os.makedirs(MODEL_STORE)
7976

80-
_create_torchserve_config_file()
77+
_set_python_path()
78+
79+
_create_torchserve_config_file(handler_service)
8180

8281
if os.path.exists(REQUIREMENTS_PATH):
8382
_install_requirements()
@@ -92,7 +91,7 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
9291
"--log-config",
9392
DEFAULT_TS_LOG_FILE,
9493
"--models",
95-
"model.mar"
94+
DEFAULT_TS_MODEL_NAME + "=" + environment.model_dir
9695
]
9796

9897
print(ts_torchserve_cmd)
@@ -107,30 +106,6 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
107106
ts_process.wait()
108107

109108

110-
def _adapt_to_ts_format(handler_service):
111-
if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
112-
os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)
113-
114-
model_archiver_cmd = [
115-
"torch-model-archiver",
116-
"--model-name",
117-
DEFAULT_TS_MODEL_NAME,
118-
"--handler",
119-
handler_service,
120-
"--export-path",
121-
DEFAULT_TS_MODEL_DIRECTORY,
122-
"--version",
123-
"1",
124-
"--extra-files",
125-
os.path.join(environment.model_dir)
126-
]
127-
128-
logger.info(model_archiver_cmd)
129-
subprocess.check_call(model_archiver_cmd)
130-
131-
_set_python_path()
132-
133-
134109
def _set_python_path():
135110
# Torchserve handles code execution by appending the export path, provided
136111
# to the model archiver, to the PYTHONPATH env var.
@@ -142,19 +117,20 @@ def _set_python_path():
142117
os.environ[PYTHON_PATH_ENV] = environment.code_dir
143118

144119

145-
def _create_torchserve_config_file():
146-
configuration_properties = _generate_ts_config_properties()
120+
def _create_torchserve_config_file(handler_service):
121+
configuration_properties = _generate_ts_config_properties(handler_service)
147122

148123
utils.write_file(TS_CONFIG_FILE, configuration_properties)
149124

150125

151-
def _generate_ts_config_properties():
126+
def _generate_ts_config_properties(handler_service):
152127
env = environment.Environment()
153128
user_defined_configuration = {
154129
"default_response_timeout": env.model_server_timeout,
155130
"default_workers_per_model": env.model_server_workers,
156131
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
157132
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
133+
"default_service_handler": handler_service + ":handle",
158134
}
159135

160136
ts_env = ts_environment.TorchServeEnvironment()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
ARG region
22
FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:1.10.2-cpu-py38-ubuntu20.04-sagemaker
33

4+
RUN pip uninstall torchserve -y && \
5+
pip install torchserve-nightly==2022.3.23.post2
6+
47
COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz
58
RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \
69
rm /sagemaker_pytorch_inference.tar.gz
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
ARG region
22
FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:1.10.2-gpu-py38-cu113-ubuntu20.04-sagemaker
33

4+
RUN pip uninstall torchserve -y && \
5+
pip install torchserve-nightly==2022.3.23.post2
6+
47
COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz
58
RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \
69
rm /sagemaker_pytorch_inference.tar.gz

test/container/1.10.2/Dockerfile.pytorch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ RUN apt-get update \
2525
RUN conda install -c conda-forge opencv \
2626
&& ln -s /opt/conda/bin/pip /usr/local/bin/pip3
2727

28-
RUN pip install torchserve==$TS_VERSION \
28+
RUN pip install torchserve-nightly==2022.3.23.post2 \
2929
&& pip install torch-model-archiver==$TS_ARCHIVER_VERSION
3030

3131
COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz

test/unit/test_model_server.py

Lines changed: 15 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
3535
@patch("os.path.exists", return_value=True)
3636
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
37-
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
37+
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
3838
def test_start_torchserve_default_service_handler(
39-
adapt,
39+
set_python_path,
4040
create_config,
4141
exists,
4242
install_requirements,
@@ -47,9 +47,8 @@ def test_start_torchserve_default_service_handler(
4747
):
4848
torchserve.start_torchserve()
4949

50-
adapt.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
51-
create_config.assert_called_once_with()
52-
exists.assert_called_once_with(REQUIREMENTS_PATH)
50+
set_python_path.assert_called_once_with()
51+
create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
5352
install_requirements.assert_called_once_with()
5453

5554
ts_model_server_cmd = [
@@ -62,7 +61,7 @@ def test_start_torchserve_default_service_handler(
6261
"--log-config",
6362
torchserve.DEFAULT_TS_LOG_FILE,
6463
"--models",
65-
"model.mar"
64+
"model=/opt/ml/model"
6665
]
6766

6867
subprocess_popen.assert_called_once_with(ts_model_server_cmd)
@@ -76,9 +75,9 @@ def test_start_torchserve_default_service_handler(
7675
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
7776
@patch("os.path.exists", return_value=True)
7877
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
79-
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
78+
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
8079
def test_start_torchserve_default_service_handler_multi_model(
81-
adapt,
80+
set_python_path,
8281
create_config,
8382
exists,
8483
install_requirements,
@@ -90,7 +89,9 @@ def test_start_torchserve_default_service_handler_multi_model(
9089
torchserve.ENABLE_MULTI_MODEL = True
9190
torchserve.start_torchserve()
9291
torchserve.ENABLE_MULTI_MODEL = False
93-
create_config.assert_called_once_with()
92+
93+
set_python_path.assert_called_once_with()
94+
create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
9495
exists.assert_called_once_with(REQUIREMENTS_PATH)
9596
install_requirements.assert_called_once_with()
9697

@@ -104,74 +105,13 @@ def test_start_torchserve_default_service_handler_multi_model(
104105
"--log-config",
105106
torchserve.DEFAULT_TS_LOG_FILE,
106107
"--models",
107-
"model.mar"
108+
"model=/opt/ml/model"
108109
]
109110

110111
subprocess_popen.assert_called_once_with(ts_model_server_cmd)
111112
sigterm.assert_called_once_with(retrieve.return_value)
112113

113114

114-
@patch("subprocess.call")
115-
@patch("subprocess.Popen")
116-
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
117-
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
118-
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
119-
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
120-
def test_start_torchserve_custom_handler_service(
121-
adapt, create_config, sigterm, retrieve, subprocess_popen, subprocess_call
122-
):
123-
handler_service = Mock()
124-
125-
torchserve.start_torchserve(handler_service)
126-
127-
adapt.assert_called_once_with(handler_service)
128-
129-
130-
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
131-
@patch("subprocess.check_call")
132-
@patch("os.makedirs")
133-
@patch("os.path.exists", return_value=False)
134-
def test_adapt_to_ts_format(path_exists, make_dir, subprocess_check_call, set_python_path):
135-
handler_service = Mock()
136-
137-
torchserve._adapt_to_ts_format(handler_service)
138-
139-
path_exists.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
140-
make_dir.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
141-
142-
model_archiver_cmd = [
143-
"torch-model-archiver",
144-
"--model-name",
145-
torchserve.DEFAULT_TS_MODEL_NAME,
146-
"--handler",
147-
handler_service,
148-
"--export-path",
149-
torchserve.DEFAULT_TS_MODEL_DIRECTORY,
150-
"--version",
151-
"1",
152-
"--extra-files",
153-
environment.model_dir
154-
]
155-
156-
subprocess_check_call.assert_called_once_with(model_archiver_cmd)
157-
set_python_path.assert_called_once_with()
158-
159-
160-
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
161-
@patch("subprocess.check_call")
162-
@patch("os.makedirs")
163-
@patch("os.path.exists", return_value=True)
164-
def test_adapt_to_ts_format_existing_path(
165-
path_exists, make_dir, subprocess_check_call, set_python_path
166-
):
167-
handler_service = Mock()
168-
169-
torchserve._adapt_to_ts_format(handler_service)
170-
171-
path_exists.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
172-
make_dir.assert_not_called()
173-
174-
175115
@patch.dict(os.environ, {torchserve.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True)
176116
def test_set_existing_python_path():
177117
torchserve._set_python_path()
@@ -193,7 +133,7 @@ def test_new_python_path():
193133
@patch("sagemaker_pytorch_serving_container.torchserve._generate_ts_config_properties")
194134
@patch("sagemaker_inference.utils.write_file")
195135
def test_create_torchserve_config_file(write_file, generate_ts_config_props):
196-
torchserve._create_torchserve_config_file()
136+
torchserve._create_torchserve_config_file(torchserve.DEFAULT_HANDLER_SERVICE)
197137

198138
write_file.assert_called_once_with(
199139
torchserve.TS_CONFIG_FILE, generate_ts_config_props.return_value
@@ -211,7 +151,7 @@ def test_generate_ts_config_properties(env, read_file):
211151
env.return_value.model_server_workers = model_server_workers
212152
env.return_value.inference_http_port = http_port
213153

214-
ts_config_properties = torchserve._generate_ts_config_properties()
154+
ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE)
215155

216156
inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port)
217157
server_timeout = "default_response_timeout={}\n".format(model_server_timeout)
@@ -228,7 +168,7 @@ def test_generate_ts_config_properties(env, read_file):
228168
def test_generate_ts_config_properties_default_workers(env, read_file):
229169
env.return_value.model_server_workers = None
230170

231-
ts_config_properties = torchserve._generate_ts_config_properties()
171+
ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE)
232172

233173
workers = "default_workers_per_model={}".format(None)
234174

@@ -244,7 +184,7 @@ def test_generate_ts_config_properties_multi_model(env, read_file):
244184
env.return_value.model_server_workers = None
245185

246186
torchserve.ENABLE_MULTI_MODEL = True
247-
ts_config_properties = torchserve._generate_ts_config_properties()
187+
ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE)
248188
torchserve.ENABLE_MULTI_MODEL = False
249189

250190
workers = "default_workers_per_model={}".format(None)

0 commit comments

Comments
 (0)