
Commit ea539ce

feat(quickstarts): add PyTorch quickstart
Signed-off-by: Xe Iaso <[email protected]>
1 parent 98feab6 commit ea539ce

2 files changed: +353 −0

docs/quickstarts/pytorch.mdx

Lines changed: 348 additions & 0 deletions
# PyTorch Quickstart

[PyTorch](https://pytorch.org/) is an open-source machine learning framework
that allows you to define, train, and deploy deep neural networks using a
simple, Python-first approach. It's built around tensor computations, which are
like NumPy arrays but with powerful GPU acceleration. PyTorch uses an automatic
differentiation engine to build dynamic computational graphs, making it highly
flexible and intuitive for both research and development. The framework is
supported by a rich ecosystem of tools and libraries for computer vision,
natural language processing, and production deployment.

To get started training your AI models with PyTorch using data stored in Tigris,
you need to do the following things:

- Create a new bucket at [storage.new](https://storage.new)
- Create an access keypair for that bucket at
  [storage.new/accesskey](https://storage.new/accesskey)
- Install the S3 connector for PyTorch
- Configure your datasets
- Run training jobs

## 1. Create a new bucket

Open [storage.new](https://storage.new) in your web browser.

Give your bucket a name and select what [storage tier](../objects/tiers.md) it
should use by default. As a general rule of thumb:

- Standard is the default. If you're not sure what you want, pick Standard.
- Infrequent Access is cheaper than Standard, but charges per gigabyte of
  retrieval.
- Instant Retrieval Archive is for long-term storage where you might need
  urgent access at any moment.
- Archive is for long-term storage where you don't mind having to wait for data
  to be brought out of cold storage.

Click "Create".

## 2. Create an access keypair for that bucket

Open [storage.new/accesskey](https://storage.new/accesskey) in your web browser.

Give the keypair a name. This name will be shown in your list of access keys, so
be sure to make it descriptive enough that you can figure out what it's for
later.

You can either give this key access to all of the buckets you have access to or
grant access to an individual bucket by name. Type the name of your bucket and
give it Editor permissions.

Click "Create".

Copy the Access Key ID, Secret Access Key, and other values into a safe place
such as your password manager. Tigris will not show you the Secret Access Key
again.

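The S3 connector picks up credentials through the standard AWS credential
chain, so one minimal way to supply them is via environment variables. A sketch
with placeholder values (any standard AWS credential mechanism should also
work; the endpoint and region are passed in code later in this guide):

```sh
# Placeholders; use the keys you saved in the previous step
export AWS_ACCESS_KEY_ID=<your Tigris Access Key ID>
export AWS_SECRET_ACCESS_KEY=<your Tigris Secret Access Key>
```
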
## 3. Install the S3 connector for PyTorch

Install the
[s3torchconnector](https://github.com/awslabs/s3-connector-for-pytorch) package.
Depending on your environment, the command could look like this:

```sh
pip install s3torchconnector
```

If you are not sure how to install Python packages in your environment, please
consult an expert.

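If you're starting from scratch, one common setup is a plain virtual
environment. A minimal sketch under that assumption (torchvision is only needed
for the image example later in this guide):

```sh
python -m venv .venv
source .venv/bin/activate
pip install torch torchvision s3torchconnector
```
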
## 4. Configure your datasets

After installing that package, import the relevant classes into your training
code:

```py
from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig
```

Now decide whether you need a **Map-Style** or **Iterable-Style** dataset:

- **Map-Style** (`S3MapDataset`): Presents the S3 objects as a random-access
  dataset (supports `len()` and indexing). It will eagerly list all objects
  under the given prefix when first accessed, which can be slow or
  memory-intensive if you have millions of objects. Use this if you need
  arbitrary index-based access or shuffling of the entire dataset in memory.
  This is also best if you have finite datasets such as the text of Wikipedia
  or a historical archive of chat logs.
- **Iterable-Style** (`S3IterableDataset`): Streams the S3 objects sequentially
  as you iterate, without preloading the whole list. Because it's built for
  sequential access patterns, it's ideal for streaming large datasets in
  batches. You sacrifice random access, but gain efficiency and lower memory
  overhead for large-scale data. This is best when you have infinite or
  constantly changing datasets that cannot possibly fit into memory, such as
  every Twitter post ever written or a statistical sample of web pages.

For a streaming training workflow, `S3IterableDataset` is typically the best
choice. Let's create an iterable dataset from a Tigris bucket:

```py
# Parameters for your dataset location on Tigris
bucket_name = "my-dataset-bucket"
prefix = "train/images"  # folder/path inside the bucket (or "" for the entire bucket)
dataset_uri = f"s3://{bucket_name}/{prefix}"

# (Optional) Prepare an S3 client config (e.g., to adjust performance settings)
cfg = S3ClientConfig()  # default config (10 Gbps throughput target, 8 MiB part size, etc.)

# Create an iterable dataset from the Tigris bucket
dataset = S3IterableDataset.from_prefix(
    dataset_uri,
    region="auto",                      # Tigris is global, so use "auto"
    endpoint="https://t3.storage.dev",  # Tigris S3 endpoint
    transform=None,                     # we'll set a transform in the next step
    s3client_config=cfg,
    enable_sharding=True,               # shard across DataLoader workers (explained later)
)
```

In the code above, we pass the S3 URI of our dataset and specify the custom
endpoint and region. The connector will connect to t3.storage.dev instead of
Amazon, using our provided credentials. The `s3client_config=cfg` argument is
optional – by default it's tuned for high throughput (e.g. a ~10 Gbps target
with multi-part downloads) and typically doesn't need adjustment. We enabled
`enable_sharding=True` so that if we use multiple data-loading workers, the
dataset will automatically partition the data among them (more on this in
section 5).

**Map-Style Example (optional)**: If you wanted to use a map-style dataset
instead, you would call `S3MapDataset.from_prefix` similarly. For example:

```py
map_dataset = S3MapDataset.from_prefix(
    dataset_uri,
    region="auto",
    endpoint="https://t3.storage.dev",
    s3client_config=cfg,
)

print(len(map_dataset))  # triggers listing all objects under the prefix
sample = map_dataset[0]  # get the first sample (S3 object)
print(sample.key, sample.read()[:100])
```

This will list all objects under the prefix and allow indexed access. Keep in
mind that the initial listing can take time and your training code may appear
unresponsive if the bucket has many thousands of objects. For large-scale
training, stick with `S3IterableDataset` unless you specifically need random
access or a finite `len(dataset)` result.

## 5. Run training jobs

By default, iterating over the S3 dataset returns an object representing each
S3 file (e.g. an S3 reader or data wrapper). You'll typically want to transform
the raw S3 object data into a usable format (e.g. a PyTorch tensor) before it
enters your model. The S3 connector allows you to provide a `transform` function
when creating the dataset – this function takes an `S3Reader` (a file-like
object for the S3 object) and should return the data in tensor form for
training.

For example, if your Tigris bucket stores images (and perhaps the directory
structure encodes labels), you can define a transform that reads the image
bytes and converts them to a tensor:

```py
from PIL import Image
import io
import torchvision.transforms as T

# Define a PyTorch transformation pipeline (adjust as needed for your data)
transform_pipeline = T.Compose([
    T.Resize((224, 224)),  # e.g. resize images to 224x224
    T.ToTensor(),          # convert PIL Image to torch.FloatTensor (C x H x W)
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # example normalization
])

def obj_to_tensor(obj):
    # Read the object content into memory
    byte_data = obj.read()
    # Open as an image (for binary image data)
    image = Image.open(io.BytesIO(byte_data)).convert("RGB")
    tensor = transform_pipeline(image)
    # (Optional) derive label from the S3 key if applicable
    key_path = obj.key  # e.g. "train/images/7/img123.png"
    # Assuming the parent directory name is the label (e.g. "7" for class 7):
    label_str = key_path.split("/")[-2]  # "7" in this example
    label = int(label_str) if label_str.isdigit() else label_str
    return tensor, label
```

This `obj_to_tensor` function does the following: it reads the object's bytes
(e.g. an image file), converts them to a PIL image, applies a series of
torchvision transforms (resize, tensor conversion, normalization), and then
parses the object's key path to get a label. We return a tuple `(tensor,
label)` for each sample. You could also return just the tensor (and handle
labels separately) depending on your use case.

Now, update the dataset to use this transform. We can either pass it during
creation or set it afterward. It's easiest to pass it in the `from_prefix`
call:

```py
dataset = S3IterableDataset.from_prefix(
    dataset_uri,
    region="auto",
    endpoint="https://t3.storage.dev",
    transform=obj_to_tensor,  # apply our custom transform to each S3 object
    enable_sharding=True,
    s3client_config=cfg,
)
```

With this transform in place, iterating over the dataset will yield
ready-to-use data. In our example, each iteration gives `(image_tensor, label)`
pairs. Under the hood, the connector will open a stream for each object and
pass an `S3Reader` to your transform, which then reads and processes the data.
This keeps memory usage in check by not loading more than one object at a time
per worker (unless you increase parallelism via multiple workers).

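Before wiring this into a full training loop, it can help to pull a single
sample as a sanity check. A small sketch (this streams the first object it
finds under the prefix, assuming the bucket layout from the example above):

```py
# Fetch one transformed sample from the stream
image_tensor, label = next(iter(dataset))
print(image_tensor.shape, label)  # e.g. torch.Size([3, 224, 224]) 7
```
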
You can customize the transform for different data formats:

- For example, if your objects are `.pt` or `.pth` files containing tensors,
  your transform might use `torch.load(obj)` directly, since the `S3Reader` is
  file-like.
- If they are CSV or text data, you could read `obj.read().decode('utf-8')` and
  parse lines, as in the sketch below.
- If your data is already in a NumPy format (e.g. `.npy`), load it with
  `np.load` from an `io.BytesIO` of the bytes.

The key is that the transform should convert the raw bytes/stream into the
model input (and target) you need.

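As a concrete sketch of the CSV case, here's a hypothetical transform that
assumes each object is a small CSV of comma-separated floats with the label in
the last column (the layout is an illustrative assumption, not something the
connector requires):

```py
import torch

def csv_to_tensors(obj):
    # Decode the whole object as UTF-8 text (fine for small CSV objects)
    text = obj.read().decode("utf-8")
    rows = [
        [float(v) for v in line.split(",")]
        for line in text.splitlines()
        if line.strip()
    ]
    data = torch.tensor(rows)
    # Assumption: features in all but the last column, integer label in the last
    return data[:, :-1], data[:, -1].long()
```
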
With the `S3IterableDataset` prepared, you can wrap it in a PyTorch
`DataLoader` to batch data and feed it into your training loop. Streaming from
S3 introduces some considerations for efficient GPU training:

**DataLoader Setup**: Use an appropriate batch size and number of worker
processes to balance throughput and memory:

```py
import torch
from torch.utils.data import DataLoader

batch_size = 32
num_workers = 4

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,          # use pinned memory for faster host-to-GPU transfers
    persistent_workers=True,  # keep workers alive between epochs (if running multiple epochs)
    # shuffle=False           # shuffle is generally not supported for IterableDataset
)
```

A few best practices are illustrated above:

- **Multiple Workers:** By using `num_workers > 0`, you allow multiple
  background processes to fetch data from S3 in parallel. With
  `enable_sharding=True` set on the dataset, each worker will get a distinct
  subset of the data (no duplicate processing). For example, with 4 workers
  each will stream roughly 1/4 of the dataset. This parallelism can
  significantly improve throughput, as each worker opens its own S3
  connections.
- **Batch Size:** Adjust `batch_size` based on your data size and GPU memory.
  Each worker loads the samples for a batch, and the default collate function
  stacks them into batch tensors before yielding. Ensure the batch is large
  enough to utilize the GPU efficiently, but not so large that the GPU runs
  out of memory or that data loading becomes a bottleneck.
- **Pinned Memory:** Setting `pin_memory=True` is recommended when transferring
  data to CUDA. It allows DataLoader workers to allocate tensors in page-locked
  memory, which accelerates the copy from host to GPU. In your training loop,
  you can then use `non_blocking=True` when calling `.to(device)` to further
  speed up transfers.
- **Persistent Workers:** By enabling `persistent_workers=True`, the worker
  processes will not be shut down after one epoch. This avoids the overhead of
  spawning processes for each epoch, which is beneficial in a streaming
  scenario (especially if each epoch still needs to scan a large dataset).
- **Prefetching:** By default, each worker will preload a couple of batches
  (`prefetch_factor=2` by default). You can tune this (e.g., increase it to 4)
  if you find your GPU waiting on data, but note that prefetching too many
  batches may consume extra memory.

Now, consider how to send data to the GPU in the training loop. Assuming your
transform returned `(data, label)` pairs as in our example, a training loop
might look like:

```py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ...  # your model
model.to(device)
optimizer = ...
criterion = ...

model.train()
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(loader):
        # Move data to GPU
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backprop and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print(f"Epoch {epoch} Batch {batch_idx}: Loss = {loss.item()}")
```

313+
A few things to note in this loop:
314+
315+
- We use `non_blocking=True` along with `pin_memory=True` (set in `DataLoader`)
316+
for faster GPU transfers.
317+
- Each iteration fetches a batch of data from the S3IterableDataset. Under the
318+
hood, each sample’s data was streamed directly from Tigris when the DataLoader
319+
worker invoked our transform. This means your CPU workers might still be
320+
reading from the network while your GPU is busy – which is fine and helps
321+
overlap I/O and compute.
322+
- **Sharding in effect**: Because we set enable_sharding=True, each worker only
323+
iterates over a portion of the dataset. This prevents duplicate data across
324+
workers. Make sure not to manually shuffle or reseed the IterableDataset in a
325+
way that breaks this – rely on the connector’s sharding. (If you need
326+
full-data shuffling, you would use a map-style dataset or implement a custom
327+
shuffle buffer, since pure streaming IterableDatasets generally don’t support
328+
a global shuffle.)
329+
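If you do want approximate shuffling while streaming, a minimal sketch of a
shuffle buffer (a hypothetical helper written for this guide, not part of the
connector):

```py
import random
from torch.utils.data import IterableDataset

class ShuffleBuffer(IterableDataset):
    """Approximately shuffles a streaming dataset using a fixed-size buffer."""

    def __init__(self, source, buffer_size=1024, seed=None):
        self.source = source
        self.buffer_size = buffer_size
        self.rng = random.Random(seed)

    def __iter__(self):
        buffer = []
        for sample in self.source:
            if len(buffer) < self.buffer_size:
                buffer.append(sample)
                continue
            # Swap a random buffered sample out for the incoming one
            idx = self.rng.randrange(self.buffer_size)
            yield buffer[idx]
            buffer[idx] = sample
        # Drain whatever is left, in random order
        self.rng.shuffle(buffer)
        yield from buffer

# Usage: wrap the streaming dataset before handing it to the DataLoader
# shuffled = ShuffleBuffer(dataset, buffer_size=4096)
```

Larger buffers shuffle better but hold more samples in memory, so size the
buffer against your per-sample footprint.
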
**Memory and Throughput Considerations**: The S3 connector is optimized to use
multi-part downloads for large objects. By default it uses an 8 MiB part size
for transfers, meaning it downloads data in 8 MiB chunks (and can do so in
parallel threads for a single object to meet the throughput target). You can
tune this via `S3ClientConfig` if needed – for example, using a larger
`part_size` for very large files or adjusting `throughput_target_gbps`. In
practice, the defaults (8 MiB parts, aiming for ~10 Gbps) work well for most
scenarios. If you observe memory spikes, ensure you're not inadvertently
reading too much data per sample (e.g., loading a huge object entirely into
memory if you only need part of it). In such cases, you could use a range-based
reader via `reader_constructor=S3ReaderConstructor.range_based()` to stream
only the needed byte ranges instead of full objects – an advanced technique
that can save memory for extremely large objects.

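As a sketch, that tuning might look like this (the numbers are illustrative
assumptions, not recommendations – benchmark against your own workload):

```py
from s3torchconnector import S3ClientConfig

cfg = S3ClientConfig(
    throughput_target_gbps=25.0,  # raise the target on faster network links
    part_size=16 * 1024 * 1024,   # 16 MiB parts for very large objects
)
```
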
Finally, monitor your CPU and network utilization. If the GPU is underutilized
(idle waiting for data), you can try increasing `num_workers` (to fetch more
data in parallel) or increasing `prefetch_factor`. If the CPU or network is
saturated, you might reduce `num_workers` or the batch size. The goal is to
keep the GPU fed with data without exhausting system resources.

sidebars.js

Lines changed: 5 additions & 0 deletions
```diff
@@ -192,6 +192,11 @@ const sidebars = {
       },
     ],
   },
+  {
+    type: "doc",
+    label: "PyTorch",
+    id: "quickstarts/pytorch",
+  },
   {
     type: "category",
     label: "SkyPilot",
```
