diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py index 95362352..5f88cff6 100644 --- a/src/sagemaker_pytorch_serving_container/torchserve.py +++ b/src/sagemaker_pytorch_serving_container/torchserve.py @@ -110,7 +110,13 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE): def _adapt_to_ts_format(handler_service): if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY): os.makedirs(DEFAULT_TS_MODEL_DIRECTORY) - + + extra_files = [] + extra_files.append(os.path.join(environment.model_dir, DEFAULT_TS_CODE_DIR, environment.Environment().module_name + ".py")) + extra_files+= [os.path.join(environment.model_dir, file) for file + in os.listdir(environment.model_dir) + if os.path.isfile(os.path.join(environment.model_dir, file)) and file != DEFAULT_TS_MODEL_SERIALIZED_FILE ] + model_archiver_cmd = [ "torch-model-archiver", "--model-name", @@ -122,7 +128,7 @@ def _adapt_to_ts_format(handler_service): "--export-path", DEFAULT_TS_MODEL_DIRECTORY, "--extra-files", - os.path.join(environment.model_dir, DEFAULT_TS_CODE_DIR, environment.Environment().module_name + ".py"), + ','.join(extra_files), "--version", "1", ]