Support >2G model export - alternative implementation | torchlib(feat) #1004

Conversation
@BowenBao Speed seems to be reasonably good. There are errors:
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1004      +/-   ##
==========================================
+ Coverage   78.64%   78.65%   +0.01%
==========================================
  Files         118      118
  Lines       15441    15435       -6
  Branches     2424     2422       -2
==========================================
- Hits        12144    12141       -3
+ Misses       2899     2896       -3
  Partials      398      398
Adding back infer_shapes and check_model made tests pass locally.
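For context, a minimal sketch of that validation step, assuming the exporter holds an in-memory onnx.ModelProto (the variable name model_proto is hypothetical):

    import onnx

    # Re-run shape inference and the model checker on the in-memory proto;
    # these are the two calls the comment above refers to.
    model_proto = onnx.shape_inference.infer_shapes(model_proto)
    onnx.checker.check_model(model_proto)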
Let me test again.
Nice, thanks! I updated this PR for review.
🚀
Needs validation with real models.
@BowenBao should we move this under a flag?
    )
    tensor_proto = onnx.helper.make_tensor(
        name=name,
        data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),

Check failure — Code scanning / lintrunner: PYLINT/E0602, MYPY/name-defined, RUFF/F821 (all three flag _type_utils as an undefined name)
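A likely resolution (an assumption; the final diff is not shown in this excerpt) is simply adding the missing import, since JitScalarType lives in torch.onnx._type_utils:

    # Hypothetical fix: import the module the three checks flag as undefined.
    from torch.onnx import _type_utils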
@@ -378,12 +374,38 @@ def eval_function(  # type: ignore[override]
         return self._graph.add_function_call(function, inputs, attributes)


 def _add_initializers(
FYI @kunal-vaishnavi: this creates an ONNX TensorProto from a torch tensor's data pointer.
Support >2G model export - alternative implementation where we build the initializers ourselves. This is a follow-up to #1003. We now build each TensorProto directly from the PyTorch tensor, which circumvents torchscript _export_onnx's 2G protobuf serialization limit and avoids an extra serialization round trip, because everything stays in memory.
A >2G model is no longer a special case, because initializers are added in a separate step.
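A minimal sketch of the approach (not the PR's exact code; the helper name _tensor_to_proto and the CPU/contiguous handling are assumptions):

    import onnx
    import torch
    from torch.onnx import _type_utils

    def _tensor_to_proto(name: str, tensor: torch.Tensor) -> onnx.TensorProto:
        # Hypothetical helper: serialize the tensor's memory into a raw
        # TensorProto without going through torchscript's exporter.
        tensor = tensor.detach().cpu().contiguous()
        return onnx.helper.make_tensor(
            name=name,
            data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),
            dims=tensor.shape,
            vals=tensor.numpy().tobytes(),  # assumes a numpy-representable dtype
            raw=True,
        )

Each initializer built this way would then be appended to the graph's initializer list in a separate pass, which is what makes the >2G case uniform with the normal path.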
TODO