Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/autogluon/cloud/backend/sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,6 @@ def predict_real_time(
test_data_image_column: Optional[str] = None,
accept: str = "application/x-parquet",
inference_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[pd.DataFrame, pd.Series]:
"""
Predict with the deployed SageMaker endpoint. A deployed SageMaker endpoint is required.
Expand Down Expand Up @@ -595,7 +594,6 @@ def predict_proba_real_time(
test_data_image_column: Optional[str] = None,
accept: str = "application/x-parquet",
inference_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[pd.DataFrame, pd.Series]:
"""
Predict probability with the deployed SageMaker endpoint. A deployed SageMaker endpoint is required.
Expand Down Expand Up @@ -704,6 +702,7 @@ def predict(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
Comment thread
suzhoum marked this conversation as resolved.
download: bool = True,
persist: bool = True,
save_path: Optional[str] = None,
Expand Down Expand Up @@ -783,6 +782,7 @@ def predict(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=inference_kwargs,
download=download,
persist=persist,
save_path=save_path,
Expand All @@ -805,6 +805,7 @@ def predict_proba(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

download: bool = True,
persist: bool = True,
save_path: Optional[str] = None,
Expand Down Expand Up @@ -889,6 +890,7 @@ def predict_proba(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=inference_kwargs,
download=download,
persist=persist,
save_path=save_path,
Expand Down Expand Up @@ -1133,6 +1135,7 @@ def _predict(
instance_count=1,
custom_image_uri=None,
wait=True,
inference_kwargs=None,
download=True,
persist=True,
save_path=None,
Expand Down Expand Up @@ -1256,6 +1259,7 @@ def _predict(
transformer_kwargs=transformer_kwargs,
model_kwargs=model_kwargs,
repack_model=repack_model,
inference_kwargs=inference_kwargs,
**transform_kwargs,
)
self._batch_transform_jobs[job_name] = batch_transform_job
Expand Down
7 changes: 7 additions & 0 deletions src/autogluon/cloud/job/sagemaker_job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from abc import abstractmethod
from typing import Optional
Expand Down Expand Up @@ -257,6 +258,7 @@ def run(
model_kwargs,
transformer_kwargs,
repack_model=False,
inference_kwargs=None,
**kwargs,
):
self._local_mode = instance_type in (LOCAL_MODE, LOCAL_MODE_GPU)
Expand All @@ -265,6 +267,10 @@ def run(
else:
model_cls = AutoGluonNonRepackInferenceModel
logger.log(20, "Creating inference model...")
inference_kwargs_str = json.dumps(inference_kwargs) if inference_kwargs is not None else None
Comment thread
suzhoum marked this conversation as resolved.
Outdated
env = {}
if len(inference_kwargs_str) > 0:
env["inference_kwargs"] = inference_kwargs_str
model = model_cls(
model_data=model_data,
role=role,
Expand All @@ -275,6 +281,7 @@ def run(
custom_image_uri=custom_image_uri,
entry_point=entry_point,
predictor_cls=predictor_cls,
env=env,
**model_kwargs,
)
logger.log(20, "Inference model created successfully")
Expand Down
6 changes: 5 additions & 1 deletion src/autogluon/cloud/predictor/cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def predict_proba_real_time(
"""
self._validate_inference_kwargs(inference_kwargs=kwargs)
return self.backend.predict_proba_real_time(
test_data=test_data, test_data_image_column=test_data_image_column, accept=accept
test_data=test_data, test_data_image_column=test_data_image_column, accept=accept, inference_kwargs=kwargs
)

def predict(
Expand All @@ -556,6 +556,7 @@ def predict(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add example code should be added to tutorials that showcase specifying kwargs. Otherwise it will be hard for users to realize how to do this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I will add some tutorials with this PR.

) -> Optional[pd.Series]:
"""
Batch inference.
Expand Down Expand Up @@ -632,6 +633,7 @@ def predict(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=kwargs,
**backend_kwargs,
)

Expand All @@ -648,6 +650,7 @@ def predict_proba(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

) -> Optional[Union[Tuple[pd.Series, Union[pd.DataFrame, pd.Series]], Union[pd.DataFrame, pd.Series]]]:
"""
Batch inference
Expand Down Expand Up @@ -730,6 +733,7 @@ def predict_proba(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=kwargs,
**backend_kwargs,
)

Expand Down
25 changes: 18 additions & 7 deletions src/autogluon/cloud/scripts/sagemaker_scripts/tabular_serve.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# flake8: noqa
import base64
import hashlib
import json
import logging
import os
import pickle
import sys
from io import BytesIO, StringIO

import pandas as pd
Expand All @@ -13,6 +16,8 @@
from autogluon.tabular import TabularPredictor

image_dir = os.path.join("/tmp", "ag_images")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


def _save_image_and_update_dataframe_column(bytes):
Expand All @@ -31,15 +36,23 @@ def _save_image_and_update_dataframe_column(bytes):

def model_fn(model_dir):
    """Load the previously saved TabularPredictor artifact for SageMaker serving.

    Parameters
    ----------
    model_dir : str
        Directory containing the saved AutoGluon predictor artifact.

    Returns
    -------
    TabularPredictor
        The loaded predictor with models persisted in memory for low-latency inference.

    Raises
    ------
    Exception
        Re-raises any failure from loading/persisting so SageMaker marks the
        container unhealthy instead of serving with a broken model.
    """
    logger.info("Loading the model")
    try:
        model = TabularPredictor.load(model_dir)
        # Keep models in memory so each request avoids a disk reload.
        model.persist_models()
        # Expose the training-time feature names for request-time column alignment
        # (read elsewhere in this serving script — presumably by transform_fn).
        globals()["column_names"] = model.original_features
    except Exception:
        # logger.exception records the full traceback; bare `raise` preserves the
        # original traceback without adding a redundant re-raise frame.
        logger.exception("Error loading the model")
        raise

    return model


def transform_fn(model, request_body, input_content_type, output_content_type="application/json"):
inference_kwargs = {}
inference_kwargs = os.environ.get("inference_kwargs", {})
if inference_kwargs:
inference_kwargs = json.loads(inference_kwargs)

if input_content_type == "application/x-parquet":
buf = BytesIO(request_body)
data = pd.read_parquet(buf)
Expand All @@ -60,9 +73,7 @@ def transform_fn(model, request_body, input_content_type, output_content_type="a
buf = bytes(request_body)
payload = pickle.loads(buf)
data = pd.read_parquet(BytesIO(payload["data"]))
inference_kwargs = payload["inference_kwargs"]
if inference_kwargs is None:
inference_kwargs = {}
inference_kwargs = payload.get("inference_kwargs", {})

else:
raise ValueError(f"{input_content_type} input content type not supported.")
Expand Down