Skip to content

Commit 6f8c71d

Browse files
committed
Change Predictor Class on Model object
1 parent 23b6d91 commit 6f8c71d

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

test/integration/sagemaker/test_mnist_default_inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import pytest
1818
import requests
1919
import sagemaker
20-
from sagemaker.pytorch import PyTorchModel
20+
from sagemaker.predictor import RealTimePredictor
21+
from sagemaker.pytorch import PyTorchModel, PyTorchPredictor
2122

2223
from integration import (
2324
model_cpu_tar,
@@ -85,6 +86,7 @@ def _test_default_inference(
8586
pytorch = PyTorchModel(
8687
model_data=model_data,
8788
role="SageMakerRole",
89+
predictor_cls=RealTimePredictor if not accelerator_type else PyTorchPredictor,
8890
entry_point=mnist_script,
8991
image=image_uri,
9092
sagemaker_session=sagemaker_session,

test/resources/resnet18/default_model/code/resnet18.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def transform_fn(model, payload, request_content_type, response_content_type):
1515

1616
logger.info("Invoking user-defined transform function")
1717

18-
if request_content_type != "application/octet-stream":
18+
if request_content_type and request_content_type != "application/octet-stream":
1919
raise RuntimeError(
2020
"Content type must be application/octet-stream. Provided: {0}".format(
2121
request_content_type

0 commit comments

Comments
 (0)