Commit b077737
documentation: add documentation for XGBoost (#1350)
1 parent d32c5f6 commit b077737

File tree

5 files changed: 256 additions, 7 deletions

doc/index.rst

Lines changed: 16 additions & 0 deletions
@@ -84,6 +84,22 @@ A managed environment for TensorFlow training and hosting on Amazon SageMaker
 
     sagemaker.tensorflow
 
+*******
+XGBoost
+*******
+A managed environment for XGBoost training and hosting on Amazon SageMaker
+
+.. toctree::
+    :maxdepth: 1
+
+    using_xgboost
+
+.. toctree::
+    :maxdepth: 2
+
+    xgboost
+
+
 ************
 Scikit-Learn
 ************

doc/using_xgboost.rst

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@

#########################################
Use XGBoost with the SageMaker Python SDK
#########################################

.. contents::

eXtreme Gradient Boosting (XGBoost) is a popular and efficient machine learning algorithm used for regression and classification tasks on tabular datasets.
It implements a technique known as gradient boosting on trees, which performs remarkably well in machine learning competitions.

Amazon SageMaker supports two ways to use the XGBoost algorithm:

* XGBoost built-in algorithm
* XGBoost open source algorithm

The XGBoost open source algorithm provides the following benefits over the built-in algorithm:

* Latest version - The open source XGBoost algorithm typically supports a more recent version of XGBoost.
  To see the XGBoost version that is currently supported,
  see `XGBoost SageMaker Estimators and Models <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/xgboost#xgboost-sagemaker-estimators-and-models>`__.
* Flexibility - Take advantage of the full range of XGBoost functionality, such as cross-validation support.
  You can add custom pre- and post-processing logic and run additional code after training.
* Scalability - The XGBoost open source algorithm has a more efficient implementation of distributed training,
  which enables it to scale out to more instances and reduce out-of-memory errors.
* Extensibility - Because the XGBoost container is open source,
  you can extend it to install additional libraries and change the version of XGBoost that it uses.
  For an example notebook that shows how to extend SageMaker containers, see `Extending our PyTorch containers <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/advanced_functionality/pytorch_extending_our_containers/pytorch_extending_our_containers.ipynb>`__.


***********************************
Use XGBoost as a Built-in Algorithm
***********************************

Amazon SageMaker provides XGBoost as a built-in algorithm that you can use like other built-in algorithms.
Using the built-in algorithm version of XGBoost is simpler than using the open source version, because you don't have to write a training script.
If you don't need the features and flexibility of open source XGBoost, consider using the built-in version.
For information about using the Amazon SageMaker XGBoost built-in algorithm, see `XGBoost Algorithm <https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html>`__
in the *Amazon SageMaker Developer Guide*.

*************************************
Use the Open Source XGBoost Algorithm
*************************************

If you want the flexibility and additional features described above, use the SageMaker open source XGBoost algorithm.

For a complete example of using the open source XGBoost algorithm, see the sample notebook at
https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/xgboost_abalone/xgboost_abalone_dist_script_mode.ipynb.


Train a Model with Open Source XGBoost
======================================

To train a model by using the Amazon SageMaker open source XGBoost algorithm:

.. |create xgboost estimator| replace:: Create a ``sagemaker.xgboost.XGBoost`` estimator
.. _create xgboost estimator: #create-an-estimator

.. |call fit| replace:: Call the estimator's ``fit`` method
.. _call fit: #call-the-fit-method

1. `Prepare a training script <#prepare-a-training-script>`_
2. |create xgboost estimator|_
3. |call fit|_

Prepare a Training Script
-------------------------

A typical training script loads data from the input channels, configures training with hyperparameters, trains a model,
and saves the model to ``model_dir`` so that it can be hosted later.
Hyperparameters are passed to your script as arguments and can be retrieved with an ``argparse.ArgumentParser`` instance.
For information about ``argparse.ArgumentParser``, see `argparse <https://docs.python.org/3/library/argparse.html>`__ in the Python documentation.

For a complete example of an XGBoost training script, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/xgboost_abalone/abalone.py.

The training script is very similar to one you might run outside of Amazon SageMaker,
but you can access useful properties of the training environment through various environment variables, including the following:

* ``SM_MODEL_DIR``: A string that represents the path where the training job writes the model artifacts.
  After training, artifacts in this directory are uploaded to Amazon S3 for model hosting.
* ``SM_NUM_GPUS``: An integer representing the number of GPUs available to the host.
* ``SM_CHANNEL_XXXX``: A string that represents the path to the directory that contains the input data for the specified channel.
  For example, if you specify two input channels named 'train' and 'test' in the XGBoost estimator's ``fit`` call, the environment variables ``SM_CHANNEL_TRAIN`` and ``SM_CHANNEL_TEST`` are set.
* ``SM_HPS``: A JSON dump of the hyperparameters that preserves JSON types (boolean, integer, etc.).

For the exhaustive list of available environment variables, see the `SageMaker Containers documentation <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.
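
For example, a script can read all of its hyperparameters at once by parsing ``SM_HPS`` instead of declaring individual ``argparse`` arguments. The following is a minimal sketch; the hyperparameter names shown are illustrative:

.. code:: python

    import json
    import os

    # SM_HPS contains every hyperparameter passed to the estimator,
    # as a JSON object with JSON types preserved, e.g. {"max_depth": 5, "eta": 0.2}.
    hps = json.loads(os.environ["SM_HPS"])
    max_depth = hps.get("max_depth", 5)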
Let's look at the main elements of the script. Starting with the ``__main__`` guard,
use a parser to read the hyperparameters passed to the estimator when creating the training job.
These hyperparameters are made available as arguments to our input script.
We also parse a number of Amazon SageMaker-specific environment variables to get information about the training environment,
such as the location of the input data and the location where we want to save the model.

.. code:: python

    import argparse
    import logging
    import os
    import pickle as pkl

    import xgboost as xgb


    if __name__ == '__main__':
        parser = argparse.ArgumentParser()

        # Hyperparameters are described here (defaults shown are illustrative).
        parser.add_argument('--num_round', type=int)
        parser.add_argument('--max_depth', type=int, default=5)
        parser.add_argument('--eta', type=float, default=0.2)
        parser.add_argument('--gamma', type=float, default=4)
        parser.add_argument('--min_child_weight', type=float, default=6)
        parser.add_argument('--subsample', type=float, default=0.7)
        parser.add_argument('--silent', type=int, default=0)
        parser.add_argument('--objective', type=str, default='reg:squarederror')

        # SageMaker specific arguments. Defaults are set in the environment variables.
        parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
        parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
        parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))

        args = parser.parse_args()

        train_hp = {
            'max_depth': args.max_depth,
            'eta': args.eta,
            'gamma': args.gamma,
            'min_child_weight': args.min_child_weight,
            'subsample': args.subsample,
            'silent': args.silent,
            'objective': args.objective
        }

        dtrain = xgb.DMatrix(args.train)
        dval = xgb.DMatrix(args.validation) if args.validation else None
        watchlist = [(dtrain, 'train'), (dval, 'validation')] if dval is not None else [(dtrain, 'train')]

        callbacks = []
        # add_checkpointing is a helper defined in the full sample script.
        # If a checkpoint is found, num_boost_round is reduced by the number
        # of iterations already run.
        prev_checkpoint, n_iterations_prev_run = add_checkpointing(callbacks)

        bst = xgb.train(
            params=train_hp,
            dtrain=dtrain,
            evals=watchlist,
            num_boost_round=(args.num_round - n_iterations_prev_run),
            xgb_model=prev_checkpoint,
            callbacks=callbacks
        )

        # Save the model to the location specified by model_dir.
        model_location = args.model_dir + '/xgboost-model'
        with open(model_location, 'wb') as f:
            pkl.dump(bst, f)
        logging.info("Stored trained model at {}".format(model_location))

Create an Estimator
-------------------

After you create your training script, create an instance of the :class:`sagemaker.xgboost.estimator.XGBoost` estimator.
Pass an IAM role that has the permissions necessary to run an Amazon SageMaker training job,
the type and number of instances to use for the training job,
and a dictionary of the hyperparameters to pass to the training script.

.. code::

    from sagemaker.xgboost.estimator import XGBoost

    # The values shown here are illustrative; the keys map to the arguments
    # parsed by the training script above.
    hyperparameters = {"num_round": 50, "max_depth": 5, "eta": 0.2, "objective": "reg:squarederror"}

    xgb_estimator = XGBoost(
        entry_point="abalone.py",
        hyperparameters=hyperparameters,
        role=role,
        train_instance_count=1,
        train_instance_type="ml.m5.2xlarge",
        framework_version="0.90-1",
    )


Call the fit Method
-------------------

After you create an estimator, call the ``fit`` method to run the training job.

.. code::

    # train_input is an S3 URI (or sagemaker.session.s3_input object) pointing to the training data.
    xgb_estimator.fit({"train": train_input})


Deploy Open Source XGBoost Models
=================================

After the training job finishes, call the ``deploy`` method of the estimator to create a predictor that you can use to get inferences from your trained model.

.. code::

    predictor = xgb_estimator.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
    test_data = xgboost.DMatrix('/path/to/data')
    predictor.predict(test_data)
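When you no longer need the endpoint, you can delete it to stop incurring charges:

.. code::

    predictor.delete_endpoint()
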
Customize inference
-------------------

In your inference script, which can be either in the same file as your training script or in a separate file,
you can customize the inference behavior by implementing the following functions:

* ``input_fn`` - how input data is handled
* ``predict_fn`` - how the model is invoked
* ``output_fn`` - how the response data is handled

These functions are optional. If you want to use the default implementations, do not implement them in your script.
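
The following is a minimal sketch of these handlers for CSV input. The function names and signatures follow the SageMaker inference handler convention; the CSV parsing shown here is an illustrative assumption, not the container's default implementation.

.. code:: python

    import numpy as np
    import xgboost as xgb


    def input_fn(request_body, request_content_type):
        """Deserialize the request payload into an XGBoost DMatrix."""
        if request_content_type == "text/csv":
            if isinstance(request_body, bytes):
                request_body = request_body.decode("utf-8")
            rows = [[float(value) for value in line.split(",")]
                    for line in request_body.strip().split("\n")]
            return xgb.DMatrix(np.array(rows))
        raise ValueError("Unsupported content type: " + request_content_type)


    def predict_fn(input_data, model):
        """Invoke the loaded model on the deserialized input."""
        return model.predict(input_data)


    def output_fn(prediction, accept):
        """Serialize predictions into the response body as CSV."""
        return ",".join(str(value) for value in prediction)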


*************************
SageMaker XGBoost Classes
*************************

For information about the SageMaker Python SDK XGBoost classes, see the following topics:

* :class:`sagemaker.xgboost.estimator.XGBoost`
* :class:`sagemaker.xgboost.model.XGBoostModel`
* :class:`sagemaker.xgboost.model.XGBoostPredictor`

***********************************
SageMaker XGBoost Docker Containers
***********************************

For information about the SageMaker XGBoost Docker container and its dependencies, see `SageMaker XGBoost Container <https://github.com/aws/sagemaker-xgboost-container>`_.

doc/xgboost.rst

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@

XGBoost Classes for Open Source Version
---------------------------------------

The Amazon SageMaker XGBoost open source framework algorithm.

.. autoclass:: sagemaker.xgboost.estimator.XGBoost
    :members:


.. autoclass:: sagemaker.xgboost.model.XGBoostModel
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: sagemaker.xgboost.model.XGBoostPredictor
    :members:
    :undoc-members:
    :show-inheritance:
src/sagemaker/xgboost/estimator.py

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ def __init__(
         Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this
         Estimator. After training is complete, calling
         :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted SageMaker endpoint
-        and returns an :class:`~sagemaker.amazon.xgboost.model.XGBoostPredictor` instance that
-        can be used to perform inference against the hosted model.
+        and returns an :class:`~sagemaker.amazon.xgboost.model.XGBoostPredictor` instance that
+        can be used to perform inference against the hosted model.

         Technical documentation on preparing XGBoost scripts for SageMaker training and using the
         XGBoost Estimator is available on the project home-page:

src/sagemaker/xgboost/model.py

Lines changed: 5 additions & 5 deletions
@@ -108,16 +108,16 @@ def __init__(
         self.model_server_workers = model_server_workers

     def prepare_container_def(self, instance_type, accelerator_type=None):
-        """Return a container definition with framework configuration set in model environment
-        variables.
+        """Return a container definition with framework configuration
+        set in model environment variables.

         Args:
             instance_type (str): The EC2 instance type to deploy this Model to. For example,
                 'ml.m5.xlarge'.
             accelerator_type (str): The Elastic Inference accelerator type to deploy to the
-                instance for loading and making inferences to the model. For example,
-                'ml.eia1.medium'.
-                Note: accelerator types are not supported by XGBoostModel.
+                instance for loading and making inferences to the model. For example,
+                'ml.eia1.medium'.
+                Note: accelerator types are not supported by XGBoostModel.

         Returns:
             dict[str, str]: A container definition object usable with the CreateModel API.
