-
Notifications
You must be signed in to change notification settings - Fork 5
Error when converting the AppleCider PyTorch model to ONNX #511
Copy link
Copy link
Closed
Description
There are at least two issues here; the first is easy to address, but the second is not as clear to me.
These lines need to be updated: https://github.com/lincc-frameworks/hyrax/blob/main/src/hyrax/verbs/train.py#L121-L124
It should be: sample = model.to_tensor(next(iter(train_data_loader)))
The other error that I'm seeing is copied below. It's not immediately clear to me what's going on here. For some additional context: the output of to_tensor for the model class is a tuple with three elements, and the forward method expects a tuple with three elements as input.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[3], [line 1](vscode-notebook-cell:?execution_count=3&line=1)
----> [1](vscode-notebook-cell:?execution_count=3&line=1) m = h.train()
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/verbs/train.py:123, in Train.run(self)
118 sample = model.to_tensor(next(iter(train_data_loader)))
119 # if isinstance(batch_sample, dict):
120 # batch_sample = model.to_tensor(batch_sample)
121 # sample = batch_sample[0] if isinstance(batch_sample, (list, tuple)) else batch_sample
--> [123](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/verbs/train.py:123) export_to_onnx(model, sample, config, context)
125 return model
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:36, in export_to_onnx(model, sample, config, ctx)
34 sample_out = None
35 if ctx["ml_framework"] == "pytorch":
---> [36](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:36) sample, sample_out = _export_pytorch_to_onnx(model, sample, onnx_output_filepath, onnx_opset_version)
37 else:
38 logger.warning("No ONNX export implementation for the given ML framework.")
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:80, in _export_pytorch_to_onnx(model, sample, output_filepath, opset_version)
77 sample_out = model(sample)
79 # export the model to ONNX format
---> [80](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/hyrax/model_exporters.py:80) export(
81 model,
82 sample,
83 output_filepath,
84 opset_version=opset_version,
85 input_names=["input"],
86 output_names=["output"],
87 dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
88 )
90 # return the input sample as a numpy array and the output of the sample run
91 # through the model as numpy array
92 return sample.numpy(), sample_out.detach().numpy()
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/__init__.py:424, in export(model, args, f, kwargs, export_params, verbose, input_names, output_names, opset_version, dynamic_axes, keep_initializers_as_inputs, dynamo, external_data, dynamic_shapes, custom_translation_table, report, optimize, verify, profile, dump_exported_program, artifacts_dir, fallback, training, operator_export_type, do_constant_folding, custom_opsets, export_modules_as_functions, autograd_inlining)
418 if dynamic_shapes:
419 raise ValueError(
420 "The exporter only supports dynamic shapes "
421 "through parameter dynamic_axes when dynamo=False."
422 )
--> [424](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/__init__.py:424) export(
425 model,
426 args,
427 f, # type: ignore[arg-type]
428 kwargs=kwargs,
429 export_params=export_params,
430 verbose=verbose is True,
431 input_names=input_names,
432 output_names=output_names,
433 opset_version=opset_version,
434 dynamic_axes=dynamic_axes,
435 keep_initializers_as_inputs=keep_initializers_as_inputs,
436 training=training,
437 operator_export_type=operator_export_type,
438 do_constant_folding=do_constant_folding,
439 custom_opsets=custom_opsets,
440 export_modules_as_functions=export_modules_as_functions,
441 autograd_inlining=autograd_inlining,
442 )
443 return None
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:522, in export(model, args, f, kwargs, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
519 if kwargs is not None:
520 args = args + (kwargs,)
--> [522](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:522) _export(
523 model,
524 args,
525 f,
526 export_params,
527 verbose,
528 training,
529 input_names,
530 output_names,
531 operator_export_type=operator_export_type,
532 opset_version=opset_version,
533 do_constant_folding=do_constant_folding,
534 dynamic_axes=dynamic_axes,
535 keep_initializers_as_inputs=keep_initializers_as_inputs,
536 custom_opsets=custom_opsets,
537 export_modules_as_functions=export_modules_as_functions,
538 autograd_inlining=autograd_inlining,
539 )
541 return None
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1457, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
1454 dynamic_axes = {}
1455 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> [1457](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1457) graph, params_dict, torch_out = _model_to_graph(
1458 model,
1459 args,
1460 verbose,
1461 input_names,
1462 output_names,
1463 operator_export_type,
1464 val_do_constant_folding,
1465 fixed_batch_size=fixed_batch_size,
1466 training=training,
1467 dynamic_axes=dynamic_axes,
1468 )
1470 if custom_opsets is None:
1471 custom_opsets = {}
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1080, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1077 args = (args,)
1079 model = _pre_trace_quant_model(model, args)
-> [1080](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:1080) graph, params, torch_out, module = _create_jit_graph(model, args)
1081 params_dict = _get_named_param_dict(graph, params)
1083 try:
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:964, in _create_jit_graph(model, args)
959 graph = _C._propagate_and_assign_input_shapes(
960 graph, flattened_args, param_count_list, False, False
961 )
962 return graph, params, torch_out, None
--> [964](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:964) graph, torch_out = _trace_and_get_graph_from_model(model, args)
965 _C._jit_pass_onnx_lint(graph)
966 state_dict = torch.jit._unique_state_dict(model)
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:871, in _trace_and_get_graph_from_model(model, args)
869 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
870 torch.set_autocast_cache_enabled(False)
--> [871](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/onnx/utils.py:871) trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
872 model,
873 args,
874 strict=False,
875 _force_outplace=False,
876 _return_inputs_states=True,
877 )
878 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
880 warn_on_static_input_change(inputs_states)
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:1504, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1502 if not isinstance(args, tuple):
1503 args = (args,)
-> [1504](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:1504) outs = ONNXTracedModule(
1505 f, strict, _force_outplace, return_inputs, _return_inputs_states
1506 )(*args, **kwargs)
1507 return outs
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> [1773](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773) return self._call_impl(*args, **kwargs)
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> [1784](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784) return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:138, in ONNXTracedModule.forward(self, *args)
135 else:
136 return tuple(out_vars)
--> [138](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:138) graph, _out = torch._C._create_graph_by_tracing(
139 wrapper,
140 in_vars + module_state,
141 _create_interpreter_name_lookup_fn(),
142 self.strict,
143 self._force_outplace,
144 )
146 if self._return_inputs:
147 return graph, outs[0], ret_inputs[0]
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:129, in ONNXTracedModule.forward.<locals>.wrapper(*args)
127 if self._return_inputs_states:
128 inputs_states.append(_unflatten(in_args, in_desc))
--> [129](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/jit/_trace.py:129) outs.append(self.inner(*trace_inputs))
130 if self._return_inputs_states:
131 inputs_states[0] = (inputs_states[0], trace_inputs)
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> [1773](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1773) return self._call_impl(*args, **kwargs)
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> [1784](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1784) return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1763, in Module._slow_forward(self, *input, **kwargs)
1761 recording_scopes = False
1762 try:
-> [1763](https://file+.vscode-resource.vscode-cdn.net/Users/drew/code/applecider/notebooks/testing/~/opt/miniconda3/envs/applecider/lib/python3.12/site-packages/torch/nn/modules/module.py:1763) result = self.forward(*input, **kwargs)
1764 finally:
1765 if recording_scopes:
TypeError: AstroMiNN.forward() takes 2 positional arguments but 4 were given
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels