Skip to content

Commit 1ccc481

Browse files
committed
fixing issue to include all model artefacts not only model.pth
1 parent 6936c08 commit 1ccc481

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/sagemaker_pytorch_serving_container/torchserve.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,15 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
110110
def _adapt_to_ts_format(handler_service):
111111
if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
112112
os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)
113-
113+
114+
extra_files = []
115+
extra_files.append(os.path.join(environment.model_dir, DEFAULT_TS_CODE_DIR, environment.Environment().module_name + ".py"))
116+
extra_files+= [os.path.join(environment.model_dir, file) for file
117+
in os.listdir(environment.model_dir)
118+
if os.path.isfile(os.path.join(environment.model_dir, file))]
119+
extra_files.remove(os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE))
120+
logger.info(extra_files)
121+
114122
model_archiver_cmd = [
115123
"torch-model-archiver",
116124
"--model-name",
@@ -122,7 +130,7 @@ def _adapt_to_ts_format(handler_service):
122130
"--export-path",
123131
DEFAULT_TS_MODEL_DIRECTORY,
124132
"--extra-files",
125-
os.path.join(environment.model_dir, DEFAULT_TS_CODE_DIR, environment.Environment().module_name + ".py"),
133+
','.join(extra_files),
126134
"--version",
127135
"1",
128136
]

0 commit comments

Comments
 (0)