Skip to content

Error when converting the AppleCider PyTorch model to ONNX #511

@drewoldag

Description

@drewoldag

There are at least 2 issues here. The first is easy to address; the second is not as clear to me.

These lines need to be updated: https://github.com/lincc-frameworks/hyrax/blob/main/src/hyrax/verbs/train.py#L121-L124
They should be replaced with: sample = model.to_tensor(next(iter(train_data_loader)))

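The other error I'm seeing is copied below, and it's not immediately clear to me what's going on. For additional context, the model class's `to_tensor` returns a tuple with three elements, and its `forward` method expects a single three-element tuple as input.

A likely cause, judging from the traceback, is in `torch.jit._get_trace_graph`. It only wraps `args` in a tuple when it is not already one, and then calls the model with `*args`. Because our sample is already a tuple, it gets unpacked into three separate positional arguments. `forward` therefore receives three arguments instead of one tuple; counting `self`, that is the "4 were given" in the error. Note that `model(sample)` in `_export_pytorch_to_onnx` works because there the tuple is passed as a single argument.

Passing `(sample,)` as `args` to `export` should hand the tuple to `forward` as one argument. Separately, the final `sample.numpy()` in that function will also fail for a tuple sample.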

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], [line 1](vscode-notebook-cell:?execution_count=3&line=1)
----> [1](vscode-notebook-cell:?execution_count=3&line=1) m = h.train()

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/verbs/train.py:123, in Train.run(self)
    118 sample = model.to_tensor(next(iter(train_data_loader)))
    119 # if isinstance(batch_sample, dict):
    120 #     batch_sample = model.to_tensor(batch_sample)
    121 # sample = batch_sample[0] if isinstance(batch_sample, (list, tuple)) else batch_sample
--> [123](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/verbs/train.py:123) export_to_onnx(model, sample, config, context)
    125 return model

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:36, in export_to_onnx(model, sample, config, ctx)
     34 sample_out = None
     35 if ctx["ml_framework"] == "pytorch":
---> [36](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:36)     sample, sample_out = _export_pytorch_to_onnx(model, sample, onnx_output_filepath, onnx_opset_version)
     37 else:
     38     logger.warning("No ONNX export implementation for the given ML framework.")

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:80, in _export_pytorch_to_onnx(model, sample, output_filepath, opset_version)
     77 sample_out = model(sample)
     79 # export the model to ONNX format
---> [80](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:80) export(
     81     model,
     82     sample,
     83     output_filepath,
     84     opset_version=opset_version,
     85     input_names=["input"],
     86     output_names=["output"],
     87     dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
     88 )
     90 # return the input sample as a numpy array and the output of the sample run
     91 # through the model as numpy array
     92 return sample.numpy(), sample_out.detach().numpy()

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/__init__.py:424, in export(model, args, f, kwargs, export_params, verbose, input_names, output_names, opset_version, dynamic_axes, keep_initializers_as_inputs, dynamo, external_data, dynamic_shapes, custom_translation_table, report, optimize, verify, profile, dump_exported_program, artifacts_dir, fallback, training, operator_export_type, do_constant_folding, custom_opsets, export_modules_as_functions, autograd_inlining)
    418 if dynamic_shapes:
    419     raise ValueError(
    420         "The exporter only supports dynamic shapes "
    421         "through parameter dynamic_axes when dynamo=False."
    422     )
--> [424](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/__init__.py:424) export(
    425     model,
    426     args,
    427     f,  # type: ignore[arg-type]
    428     kwargs=kwargs,
    429     export_params=export_params,
    430     verbose=verbose is True,
    431     input_names=input_names,
    432     output_names=output_names,
    433     opset_version=opset_version,
    434     dynamic_axes=dynamic_axes,
    435     keep_initializers_as_inputs=keep_initializers_as_inputs,
    436     training=training,
    437     operator_export_type=operator_export_type,
    438     do_constant_folding=do_constant_folding,
    439     custom_opsets=custom_opsets,
    440     export_modules_as_functions=export_modules_as_functions,
    441     autograd_inlining=autograd_inlining,
    442 )
    443 return None

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:522, in export(model, args, f, kwargs, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    519 if kwargs is not None:
    520     args = args + (kwargs,)
--> [522](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:522) _export(
    523     model,
    524     args,
    525     f,
    526     export_params,
    527     verbose,
    528     training,
    529     input_names,
    530     output_names,
    531     operator_export_type=operator_export_type,
    532     opset_version=opset_version,
    533     do_constant_folding=do_constant_folding,
    534     dynamic_axes=dynamic_axes,
    535     keep_initializers_as_inputs=keep_initializers_as_inputs,
    536     custom_opsets=custom_opsets,
    537     export_modules_as_functions=export_modules_as_functions,
    538     autograd_inlining=autograd_inlining,
    539 )
    541 return None

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1457, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1454     dynamic_axes = {}
   1455 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> [1457](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1457) graph, params_dict, torch_out = _model_to_graph(
   1458     model,
   1459     args,
   1460     verbose,
   1461     input_names,
   1462     output_names,
   1463     operator_export_type,
   1464     val_do_constant_folding,
   1465     fixed_batch_size=fixed_batch_size,
   1466     training=training,
   1467     dynamic_axes=dynamic_axes,
   1468 )
   1470 if custom_opsets is None:
   1471     custom_opsets = {}

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1080, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1077     args = (args,)
   1079 model = _pre_trace_quant_model(model, args)
-> [1080](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1080) graph, params, torch_out, module = _create_jit_graph(model, args)
   1081 params_dict = _get_named_param_dict(graph, params)
   1083 try:

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:964, in _create_jit_graph(model, args)
    959     graph = _C._propagate_and_assign_input_shapes(
    960         graph, flattened_args, param_count_list, False, False
    961     )
    962     return graph, params, torch_out, None
--> [964](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:964) graph, torch_out = _trace_and_get_graph_from_model(model, args)
    965 _C._jit_pass_onnx_lint(graph)
    966 state_dict = torch.jit._unique_state_dict(model)

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:871, in _trace_and_get_graph_from_model(model, args)
    869 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    870 torch.set_autocast_cache_enabled(False)
--> [871](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:871) trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    872     model,
    873     args,
    874     strict=False,
    875     _force_outplace=False,
    876     _return_inputs_states=True,
    877 )
    878 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
    880 warn_on_static_input_change(inputs_states)

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:1504, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1502 if not isinstance(args, tuple):
   1503     args = (args,)
-> [1504](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:1504) outs = ONNXTracedModule(
   1505     f, strict, _force_outplace, return_inputs, _return_inputs_states
   1506 )(*args, **kwargs)
   1507 return outs

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> [1773](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773)     return self._call_impl(*args, **kwargs)

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1784](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784)     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:138, in ONNXTracedModule.forward(self, *args)
    135     else:
    136         return tuple(out_vars)
--> [138](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:138) graph, _out = torch._C._create_graph_by_tracing(
    139     wrapper,
    140     in_vars + module_state,
    141     _create_interpreter_name_lookup_fn(),
    142     self.strict,
    143     self._force_outplace,
    144 )
    146 if self._return_inputs:
    147     return graph, outs[0], ret_inputs[0]

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:129, in ONNXTracedModule.forward.<locals>.wrapper(*args)
    127 if self._return_inputs_states:
    128     inputs_states.append(_unflatten(in_args, in_desc))
--> [129](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:129) outs.append(self.inner(*trace_inputs))
    130 if self._return_inputs_states:
    131     inputs_states[0] = (inputs_states[0], trace_inputs)

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> [1773](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773)     return self._call_impl(*args, **kwargs)

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1784](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784)     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1763, in Module._slow_forward(self, *input, **kwargs)
   1761         recording_scopes = False
   1762 try:
-> [1763](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1763)     result = self.forward(*input, **kwargs)
   1764 finally:
   1765     if recording_scopes:

TypeError: AstroMiNN.forward() takes 2 positional arguments but 4 were given

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions