Commit bd8e42a

chad119 (Chad Chiang) authored and nargokul committed
add integ test for base_model_builder_deploy and remove print statement (#1612)
* Fix: move the functionality from latest_container_image to retrieve
* address some comments from Gokul and add unit test
* remove extra functions and rewrite the test
* fix unit test
* fix for other unit test
* unit test fix
* fix unit test: add one more condition
* more unit tests fix
* remove redundant files
* remove the special condition and fix the unit test
* add integ test for base_model_builder_deploy and remove print statement
* add endpoint name for ipynb
* solve the integ test

---------

Co-authored-by: Chad Chiang <[email protected]>
Co-authored-by: Gokul Anantha Narayanan <[email protected]>
1 parent 6985ec5 commit bd8e42a

File tree: 3 files changed (+193, -1 lines)


src/sagemaker/image_uris.py

Lines changed: 0 additions & 1 deletion
@@ -460,7 +460,6 @@ def _get_inference_tool(inference_tool, instance_type):

 def _get_latest_versions(list_of_versions):
     """Extract the latest version from the input list of available versions."""
-    print("SORT")
     return sorted(list_of_versions, reverse=True)[0]

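For context, a minimal sketch of how _get_latest_versions behaves after this change; the version strings below are hypothetical and not taken from the SDK's config:

# Sketch only, not part of the commit: _get_latest_versions copied from above,
# exercised with made-up version strings.
def _get_latest_versions(list_of_versions):
    """Extract the latest version from the input list of available versions."""
    return sorted(list_of_versions, reverse=True)[0]

print(_get_latest_versions(["1.3-1", "1.7-1", "1.5-1"]))
# -> "1.7-1": the list is sorted lexicographically in reverse and the first
#    element returned; the stray "SORT" debug print no longer appears on stdout.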

src/sagemaker/modules/testing_notebooks/base_model_builder_deploy.ipynb

Lines changed: 4 additions & 0 deletions
@@ -239,6 +239,7 @@
    "outputs": [],
    "source": [
     "real_time_predictor = model_builder.deploy(\n",
+    "    endpoint_name=\"test-real-time-predictor\",\n",
     "    initial_instance_count=1)"
    ]
   },
@@ -265,6 +266,7 @@
    "from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig\n",
    "\n",
    "serverless_predictor = model_builder.deploy(\n",
+   "    endpoint_name=\"test-serverless-predictor\",\n",
    "    inference_config=ServerlessInferenceConfig())"
   ]
  },
@@ -292,6 +294,7 @@
    "from sagemaker.async_inference import AsyncInferenceConfig\n",
    "\n",
    "async_predictor = model_builder.deploy(\n",
+   "    endpoint_name=\"test-async-predictor\",\n",
    "    inference_config=AsyncInferenceConfig(\n",
    "        output_path=s3_path_join(\"s3://\", bucket, \"async_inference/output\")),\n",
    ")"
@@ -316,6 +319,7 @@
    "\n",
    "\n",
    "batch_predictor = model_builder.deploy(\n",
+   "    endpoint_name=\"test-batch-predictor\",\n",
    "    initial_instance_count=1,\n",
    "    inference_config=BatchTransformInferenceConfig(\n",
    "        instance_count=1,\n",
New file: integration test for base_model_builder_deploy (TestBaseModelBuilderDeploy)

Lines changed: 189 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import unittest

from docker.utils import exclude_paths

from sagemaker import Session, get_execution_role

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

import pandas as pd

import os

from sagemaker_core.main.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, OutputDataConfig, \
    ResourceConfig, StoppingCondition
import uuid
from sagemaker.serve.builder.model_builder import ModelBuilder, BatchTransformInferenceConfig
import pandas as pd
import numpy as np
from sagemaker.serve import InferenceSpec, SchemaBuilder
from sagemaker_core.main.resources import TrainingJob
from xgboost import XGBClassifier

from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig

from sagemaker.s3_utils import s3_path_join
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker_core.main.resources import Endpoint


class TestBaseModelBuilderDeploy(unittest.TestCase):
    def setUp(self):
        sagemaker_session = Session()
        role = get_execution_role()
        region = sagemaker_session.boto_region_name
        bucket = sagemaker_session.default_bucket()
        self.bucket = bucket
        # Get IRIS Data

        iris = load_iris()
        iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
        iris_df['target'] = iris.target

        # Prepare Data

        os.makedirs('./data', exist_ok=True)

        iris_df = iris_df[['target'] + [col for col in iris_df.columns if col != 'target']]

        train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42)

        train_data.to_csv('./data/train.csv', index=False, header=False)
        test_data.to_csv('./data/test.csv', index=False, header=False)

        # Remove the target column from the testing data. We will use this to call invoke_endpoint later
        test_data_no_target = test_data.drop('target', axis=1)

        prefix = "DEMO-scikit-iris"
        TRAIN_DATA = "train.csv"
        DATA_DIRECTORY = "data"

        train_input = sagemaker_session.upload_data(
            DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY)
        )

        s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA)
        s3_output_path = "s3://{}/{}/output".format(bucket, prefix)

        print(s3_input_path)
        print(s3_output_path)

        image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"

        class XGBoostSpec(InferenceSpec):
            def load(self, model_dir: str):
                print(model_dir)
                model = XGBClassifier()
                model.load_model(model_dir + "/xgboost-model")
                return model

            def invoke(self, input_object: object, model: object):
                prediction_probabilities = model.predict_proba(input_object)
                predictions = np.argmax(prediction_probabilities, axis=1)
                return predictions

        data = {
            'Name': ['Alice', 'Bob', 'Charlie']
        }
        df = pd.DataFrame(data)
        training_job_name = str(uuid.uuid4())
        schema_builder = SchemaBuilder(sample_input=df, sample_output=df)

        training_job = TrainingJob.create(
            training_job_name=training_job_name,
            hyper_parameters={
                'objective': 'multi:softmax',
                'num_class': '3',
                'num_round': '10',
                'eval_metric': 'merror'
            },
            algorithm_specification=AlgorithmSpecification(
                training_image=image,
                training_input_mode='File'
            ),
            role_arn=role,
            input_data_config=[
                Channel(
                    channel_name='train',
                    content_type='csv',
                    compression_type='None',
                    record_wrapper_type='None',
                    data_source=DataSource(
                        s3_data_source=S3DataSource(
                            s3_data_type='S3Prefix',
                            s3_uri=s3_input_path,
                            s3_data_distribution_type='FullyReplicated'
                        )
                    )
                )
            ],
            output_data_config=OutputDataConfig(
                s3_output_path=s3_output_path
            ),
            resource_config=ResourceConfig(
                instance_type='ml.m4.xlarge',
                instance_count=1,
                volume_size_in_gb=30
            ),
            stopping_condition=StoppingCondition(
                max_runtime_in_seconds=600
            )
        )
        training_job.wait()

        self.model_builder = ModelBuilder(
            name="ModelBuilderTest",
            model_path=training_job.model_artifacts.s3_model_artifacts,
            role_arn=role,
            inference_spec=XGBoostSpec(),
            image_uri=image,
            schema_builder=schema_builder,
            instance_type="ml.c6i.xlarge"
        )
        self.model_builder.build()

    def test_real_time_deployment(self):
        real_time_predictor = self.model_builder.deploy(
            endpoint_name="test",
            initial_instance_count=1)

        assert real_time_predictor is not None

    def test_serverless_deployment(self):
        serverless_predictor = self.model_builder.deploy(
            endpoint_name="test1",
            inference_config=ServerlessInferenceConfig())

        assert serverless_predictor is not None

    def test_async_deployment(self):
        async_predictor = self.model_builder.deploy(
            endpoint_name="test2",
            inference_config=AsyncInferenceConfig(
                output_path=s3_path_join("s3://", self.bucket, "async_inference/output")),
        )

        assert async_predictor is not None

    def tearDown(self):
        endpoints = Endpoint.get_all()
        for endpoint in endpoints:
            endpoint.delete()
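The test methods only assert that deploy() returns a predictor. If one also wanted to exercise the real-time endpoint with the held-out features prepared in setUp, a hedged sketch could look like the following; it assumes setUp additionally stores the data frame on self and that deploy() returns a standard sagemaker Predictor whose predict() accepts a DataFrame, neither of which is part of this commit:

    # Sketch only, not part of the commit. Assumes setUp() additionally does
    #   self.test_data_no_target = test_data_no_target
    # and that deploy() returns a Predictor whose predict() accepts a DataFrame.
    def test_real_time_invocation(self):
        predictor = self.model_builder.deploy(
            endpoint_name="test-invoke",  # hypothetical endpoint name
            initial_instance_count=1,
        )
        predictions = predictor.predict(self.test_data_no_target)
        assert len(predictions) == len(self.test_data_no_target)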
