Commit 9fc5960

Implement save_model and tensor sharding (#39)

Authored by justinchuby and Copilot
This pull request introduces support for sharding large ONNX model weights into multiple safetensors files, adds a new high-level `save_model` API, and provides comprehensive tests for these new features. The main focus is on enabling the saving of large models by splitting their weights into manageable chunks, improving usability and scalability.

**Major new features and improvements:**

### Sharding and Size Parsing Functionality

- Added logic to shard tensors into multiple safetensors files via a new `max_shard_size` parameter on `save_file`. This includes helper functions for parsing human-readable size strings (e.g., "5GB", "100MB") and generating consistent shard filenames. When sharding occurs, an index file is also created to map tensors to their respective shards.

### New API for Model Saving

- Introduced a new `save_model` function in the public API, allowing users to save an ONNX model and its weights (optionally sharded) in one call. This function enforces that the external data file uses the `.safetensors` extension and supports the new sharding mechanism.

### Testing and Validation

- Added extensive unit tests covering the new `save_model` API, the sharding logic, size-string parsing, and filename generation. Tests verify correct file outputs, error handling, and that sharding and indexing work as expected for both ONNX and IR models.

These changes significantly enhance the library's ability to handle large models and improve the developer experience with a more user-friendly API.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
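The human-readable size strings accepted by `max_shard_size` can be handled by a small parsing helper. The following is a hypothetical sketch of such a parser, not the library's internal implementation; it assumes decimal units (so `"5GB"` is 5 * 10^9 bytes), matching the Hugging Face sharding convention this format is compatible with:

```python
import re

def parse_size(size: str) -> int:
    """Parse a human-readable size string like "5GB" or "100MB" into bytes.

    Illustrative helper only; the library's real parser may differ, e.g. in
    whether "GB" means 10**9 or 2**30 bytes. Decimal units are assumed here.
    """
    units = {"KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12}
    match = re.fullmatch(r"(\d+(?:\.\d+)?)\s*([KMGT]B)", size.strip(), re.IGNORECASE)
    if match is None:
        raise ValueError(f"Unrecognized size string: {size!r}")
    value, unit = match.groups()
    return int(float(value) * units[unit.upper()])
```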
Parent commit: 686b1dd

8 files changed: 1,130 additions and 39 deletions


.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+.DS_Store
```

README.md

Lines changed: 49 additions & 0 deletions
````diff
@@ -75,6 +75,55 @@ model_with_external_data = onnx_safetensors.save_file(model, data_path, base_dir
 onnx.save(model_with_external_data, os.path.join(base_dir, "model_using_safetensors.onnx"))
 ```
 
+### Save an ONNX model with safetensors weights
+
+The `save_model` function is a convenient way to save both the ONNX model and its weights to separate files:
+
+```python
+import onnx_safetensors
+
+# Provide your ONNX model here
+model: onnx.ModelProto
+
+# Save model and weights in one step
+# This creates model.onnx and model.safetensors
+onnx_safetensors.save_model(model, "model.onnx")
+
+# You can also specify a custom name for the weights file
+onnx_safetensors.save_model(model, "model.onnx", external_data="weights.safetensors")
+```
+
+### Shard large models
+
+For large models, you can automatically shard the weights across multiple safetensors files:
+
+```python
+import onnx_safetensors
+
+# Provide your ONNX model here
+model: onnx.ModelProto
+
+# Shard the model into multiple files (e.g., 5GB per shard)
+# This creates:
+# - model.onnx
+# - model-00001-of-00003.safetensors
+# - model-00002-of-00003.safetensors
+# - model-00003-of-00003.safetensors
+# - model.safetensors.index.json (index file mapping tensors to shards)
+onnx_safetensors.save_model(model, "model.onnx", max_shard_size="5GB")
+
+# You can also use save_file with sharding
+onnx_safetensors.save_file(
+    model,
+    "weights.safetensors",
+    base_dir="path/to/save",
+    max_shard_size="5GB"
+)
+```
+
+The sharding format is compatible with the Hugging Face transformers library, making it easy to share and load models across different frameworks.
+
 ## Examples
 
 - [Tutorial notebook](examples/tutorial.ipynb)
+- [save_model and sharding examples](examples/save_model_sharding.py)
````

examples/save_model_sharding.py

Lines changed: 214 additions & 0 deletions
```python
"""Example demonstrating save_model and model sharding functionality.

This example shows how to:
1. Save an ONNX model with safetensors weights using save_model
2. Shard large models across multiple safetensors files
3. Load and verify sharded models with ONNX Runtime
"""

import glob
import json
import os

import numpy as np
import onnx
import onnx.helper
import onnx.numpy_helper
import onnxruntime as ort

import onnx_safetensors


def create_example_model(large: bool = False) -> onnx.ModelProto:
    """Create an example ONNX model for demonstration.

    Args:
        large: If True, creates a larger model to demonstrate sharding.

    Returns:
        An ONNX model.
    """
    if large:
        # Create a larger model with multiple weight matrices to demonstrate sharding
        weights1 = np.random.randn(1000, 1000).astype(np.float32)  # ~4MB
        weights2 = np.random.randn(1000, 2000).astype(np.float32)  # ~8MB
        weights3 = np.random.randn(2000, 1000).astype(np.float32)  # ~8MB

        graph = onnx.helper.make_graph(
            [
                onnx.helper.make_node("MatMul", ["input", "weights1"], ["temp1"]),
                onnx.helper.make_node("MatMul", ["temp1", "weights2"], ["temp2"]),
                onnx.helper.make_node("MatMul", ["temp2", "weights3"], ["output"]),
            ],
            "large_model",
            inputs=[
                onnx.helper.make_tensor_value_info(
                    "input", onnx.TensorProto.FLOAT, [1, 1000]
                ),
            ],
            outputs=[
                onnx.helper.make_tensor_value_info(
                    "output", onnx.TensorProto.FLOAT, [1, 1000]
                ),
            ],
            initializer=[
                onnx.numpy_helper.from_array(weights1, name="weights1"),
                onnx.numpy_helper.from_array(weights2, name="weights2"),
                onnx.numpy_helper.from_array(weights3, name="weights3"),
            ],
        )
    else:
        # Create a simple model
        weights = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

        graph = onnx.helper.make_graph(
            [
                onnx.helper.make_node("Add", ["input", "weights"], ["output"]),
            ],
            "simple_model",
            inputs=[
                onnx.helper.make_tensor_value_info(
                    "input", onnx.TensorProto.FLOAT, [2, 3]
                ),
            ],
            outputs=[
                onnx.helper.make_tensor_value_info(
                    "output", onnx.TensorProto.FLOAT, [2, 3]
                ),
            ],
            initializer=[onnx.numpy_helper.from_array(weights, name="weights")],
        )

    model = onnx.helper.make_model(
        graph, opset_imports=[onnx.helper.make_opsetid("", 14)], ir_version=10
    )
    return model


def example_basic_save_model():
    """Example 1: Basic usage of save_model."""
    print("Example 1: Basic save_model usage")
    print("=" * 50)

    # Create a simple model
    model = create_example_model(large=False)

    # Save model and weights
    # This creates:
    # - simple_model.onnx (ONNX model file)
    # - simple_model.safetensors (weights file)
    onnx_safetensors.save_model(model, "simple_model.onnx")

    print("✓ Saved simple_model.onnx and simple_model.safetensors")

    # Load and verify the model with ONNX Runtime
    sess = ort.InferenceSession("simple_model.onnx", providers=["CPUExecutionProvider"])
    input_data = np.ones((2, 3), dtype=np.float32)
    outputs = sess.run(None, {"input": input_data})

    print("✓ Model runs successfully with ONNX Runtime")
    print(f"  Output shape: {outputs[0].shape}")
    print()


def example_custom_weights_file():
    """Example 2: Specify a custom name for the weights file."""
    print("Example 2: Custom weights file name")
    print("=" * 50)

    model = create_example_model(large=False)

    # Save with custom weights file name
    # This creates:
    # - my_model.onnx
    # - custom_weights.safetensors
    onnx_safetensors.save_model(
        model, "my_model.onnx", external_data="custom_weights.safetensors"
    )

    print("✓ Saved my_model.onnx with custom_weights.safetensors")
    print()


def example_model_sharding():
    """Example 3: Shard a large model across multiple files."""
    print("Example 3: Model sharding")
    print("=" * 50)

    # Create a larger model
    model = create_example_model(large=True)

    # Shard the model with 5MB per shard
    # This creates:
    # - large_model.onnx
    # - large_model-00001-of-00004.safetensors
    # - large_model-00002-of-00004.safetensors
    # - large_model-00003-of-00004.safetensors
    # - large_model-00004-of-00004.safetensors
    # - large_model.safetensors.index.json (index file)
    onnx_safetensors.save_model(model, "large_model.onnx", max_shard_size="5MB")

    print("✓ Saved large_model.onnx with sharded weights")
    print("  Files created:")

    # List the created shard files
    shard_files = sorted(glob.glob("large_model-*.safetensors"))
    for shard_file in shard_files:
        size_mb = os.path.getsize(shard_file) / (1024 * 1024)
        print(f"  - {shard_file} ({size_mb:.2f} MB)")

    # Check for index file
    if os.path.exists("large_model.safetensors.index.json"):
        with open("large_model.safetensors.index.json") as f:
            index = json.load(f)
        print(f"  ✓ Index file created with {len(index['weight_map'])} tensors mapped")

    # Verify the sharded model works with ONNX Runtime
    sess = ort.InferenceSession("large_model.onnx", providers=["CPUExecutionProvider"])
    input_data = np.random.randn(1, 1000).astype(np.float32)
    outputs = sess.run(None, {"input": input_data})

    print("✓ Sharded model runs successfully with ONNX Runtime")
    print(f"  Output shape: {outputs[0].shape}")
    print()


def example_save_file_with_sharding():
    """Example 4: Use save_file with sharding for more control."""
    print("Example 4: save_file with sharding")
    print("=" * 50)

    model = create_example_model(large=True)

    # Save only the weights with sharding
    # Note: This doesn't save the ONNX model file itself
    onnx_safetensors.save_file(
        model,
        "weights_only.safetensors",
        base_dir=".",
        max_shard_size="5MB",
        replace_data=False,  # Don't modify the model
    )

    print("✓ Saved sharded weights without modifying the model")

    shard_files = sorted(glob.glob("weights_only-*.safetensors"))
    print(f"  Created {len(shard_files)} shard files")
    print()


if __name__ == "__main__":
    print("ONNX-Safetensors: save_model and Sharding Examples")
    print("=" * 50)
    print()

    # Run all examples
    example_basic_save_model()
    example_custom_weights_file()
    example_model_sharding()
    example_save_file_with_sharding()

    print("All examples completed successfully! ✓")
    print()
    print("Note: This example created several files for demonstration.")
    print("You can safely delete them after reviewing.")
```
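The shard layout produced in Example 3 can be approximated by a greedy first-fit pass over the tensors, followed by the `-NNNNN-of-NNNNN` naming pattern shown in the README. A rough sketch, with illustrative names (`shard_tensors`, `shard_filename`) that are not the library's internals:

```python
def shard_tensors(sizes: dict[str, int], max_shard_size: int) -> list[list[str]]:
    """Greedily assign tensors to shards so no shard exceeds max_shard_size.

    A tensor larger than the limit still gets a shard of its own.
    """
    shards: list[list[str]] = [[]]
    current = 0
    for name, size in sizes.items():
        # Start a new shard when adding this tensor would overflow the current one.
        if current > 0 and current + size > max_shard_size:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += size
    return shards


def shard_filename(base: str, index: int, total: int) -> str:
    """Produce names like model-00001-of-00003.safetensors (1-based index)."""
    return f"{base}-{index:05d}-of-{total:05d}.safetensors"


# Three tensors of 4, 8, and 8 bytes with a 10-byte limit split into 3 shards.
shards = shard_tensors({"w1": 4, "w2": 8, "w3": 8}, max_shard_size=10)
names = [shard_filename("model", i + 1, len(shards)) for i in range(len(shards))]
```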

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -33,6 +33,7 @@ dependencies = [
     "onnx>=1.16",
     "safetensors",
     "onnx-ir",
+    "tqdm",
 ]
 
 [project.urls]
```

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,3 +3,4 @@ pytest-cov
 lintrunner
 lintrunner-adapters
 parameterized
+onnxruntime
```

src/onnx_safetensors/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
     "replace_tensors",
     "save",
     "save_file",
+    "save_model",
 ]
 
 from onnx_safetensors._safetensors_io import (
@@ -16,6 +17,7 @@
     replace_tensors,
     save,
     save_file,
+    save_model,
 )
 
 __version__ = "1.2.0"
```
