
Commit 414c83a

Address comments

1 parent: 92def12

2 files changed: +7 −13 lines

test/integration/sagemaker/test_default_inference.py

Lines changed: 6 additions & 12 deletions
```diff
@@ -102,18 +102,12 @@ def _test_default_inference(
         sagemaker_session=sagemaker_session,
     )
     with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=30):
-        # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value
-        if accelerator_type is not None:
-            predictor = pytorch.deploy(
-                initial_instance_count=1,
-                instance_type=instance_type,
-                accelerator_type=accelerator_type,
-                endpoint_name=endpoint_name,
-            )
-        else:
-            predictor = pytorch.deploy(
-                initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
-            )
+        predictor = pytorch.deploy(
+            initial_instance_count=1,
+            instance_type=instance_type,
+            accelerator_type=accelerator_type,
+            endpoint_name=endpoint_name,
+        )
 
         if accelerator_type:
             batch_size = 100
```
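Collapsing the EI and CPU/GPU branches works because `accelerator_type` already defaults to `None` in the SageMaker Python SDK's `deploy` signature, so passing `None` through explicitly is equivalent to omitting the argument. A minimal sketch of that pattern, with a hypothetical `deploy` function standing in for the real SDK method:

```python
# Sketch: a keyword argument defaulting to None makes the old if/else at the
# call site unnecessary. `deploy` here is a hypothetical stand-in for
# sagemaker.pytorch estimator deploy, whose accelerator_type also defaults to None.
def deploy(initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None):
    config = {
        "initial_instance_count": initial_instance_count,
        "instance_type": instance_type,
        "endpoint_name": endpoint_name,
    }
    if accelerator_type is not None:
        # Only EI deployments attach an accelerator to the endpoint.
        config["accelerator_type"] = accelerator_type
    return config

# Both call sites produce the same config, so the caller never needs to branch:
assert deploy(1, "ml.c4.xlarge") == deploy(1, "ml.c4.xlarge", accelerator_type=None)
```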

test/resources/mnist/default_model_eia/code/mnist.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ def predict_fn(input_data, model):
     logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
     # With EI, client instance should be CPU for cost-efficiency.
     # Sub-graphs with unsupported arguments run locally. Server runs with CUDA
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device('cpu')
     model = model.to(device)
     input_data = input_data.to(device)
     with torch.no_grad():
```
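Pinning the device to CPU is the right call for Elastic Inference: the host instance only stages tensors, while accelerated sub-graphs run on the attached EI device, so the client never needs CUDA. A sketch of the surrounding handler, assuming the EI-enabled PyTorch build whose `torch.jit.optimized_execution` accepts a target-device argument (stock PyTorch's version does not):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def predict_fn(input_data, model):
    # Sketch of the handler around the changed line. With Elastic Inference,
    # the host only stages tensors; supported sub-graphs execute on the
    # remote EI accelerator, so CPU is always the correct client-side device.
    logger.info('Performing EIA inference with input of size %s', input_data.shape)
    device = torch.device('cpu')
    model = model.to(device)
    input_data = input_data.to(device)
    with torch.no_grad():
        # The two-argument form targeting 'eia:0' is specific to the
        # EI-enabled PyTorch build; it routes supported sub-graphs to the
        # attached accelerator while unsupported ones run locally.
        with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
            return model(input_data)
```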
