Skip to content

Improve the compatibility dealing with large ONNX proto in ORTOptimizer and ORTQuantizer#332

Merged
JingyaHuang merged 33 commits into
mainfrom
add_onnx_export_options
Nov 28, 2022
Merged

Improve the compatibility dealing with large ONNX proto in ORTOptimizer and ORTQuantizer#332
JingyaHuang merged 33 commits into
mainfrom
add_onnx_export_options

Conversation

@JingyaHuang
Copy link
Copy Markdown
Contributor

@JingyaHuang JingyaHuang commented Aug 2, 2022

What does this PR do?

  • Add all_tensors_to_one_file option to ORTOptimizer when exporting the ONNX model.

P.S. only the ONNXModel of optimization has added the option, not quantization.

  • Refactoring the compatibility of ORTOptimizer and ORTQuantizer in cases of large ONNX proto(export path / load ONNX)
  • Update the optimization and quantization examples with the new option.

Fixes #222

Related to onnx/onnx#4394

P.S.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

HuggingFaceDocBuilderDev commented Aug 8, 2022

The documentation is not available anymore as the PR was closed or merged.

@JingyaHuang JingyaHuang changed the base branch from main to add-causallm-with-pkv August 11, 2022 09:17
@JingyaHuang JingyaHuang changed the base branch from add-causallm-with-pkv to main August 11, 2022 09:17
@JingyaHuang JingyaHuang changed the title Add more optional configs for exporting large ModelProto Improve the compatibility dealing with large ONNX proto in ORTOptimizer and ORTQuantizer Aug 11, 2022
@JingyaHuang JingyaHuang changed the base branch from main to add-causallm-with-pkv August 12, 2022 16:24
@JingyaHuang JingyaHuang changed the base branch from add-causallm-with-pkv to main August 12, 2022 16:24
@JingyaHuang JingyaHuang marked this pull request as ready for review August 12, 2022 17:41
Copy link
Copy Markdown
Contributor

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR @JingyaHuang!!
I just have one comment which may be unrelated to this PR.

Comment thread optimum/onnxruntime/quantization.py Outdated
@JingyaHuang JingyaHuang requested a review from regisss August 23, 2022 07:43
Copy link
Copy Markdown
Contributor

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks Jingya!!

Comment thread optimum/onnxruntime/optimization.py Outdated
Comment thread optimum/onnxruntime/quantization.py
Comment thread examples/onnxruntime/optimization/multiple-choice/run_swag.py
Comment thread optimum/onnxruntime/optimization.py Outdated
Copy link
Copy Markdown
Contributor Author

@JingyaHuang JingyaHuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reviewing, I will put the arguments in ORTModel as suggested and make necessary changes in the examples.

Comment thread optimum/onnxruntime/optimization.py Outdated
Comment thread examples/onnxruntime/optimization/multiple-choice/run_swag.py
Copy link
Copy Markdown
Collaborator

@echarlaix echarlaix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @JingyaHuang !

@JingyaHuang
Copy link
Copy Markdown
Contributor Author

Moved config(model and ORT) saving after saving the model, otherwise will get the following error as configs are saved in the folder:

RuntimeError: Output directory (results/albert_largev1_squad2) for external data is not empty.

While running the template example, I got the error message in the evaluation phase(no problem with the optimization, and models with external files are properly saved):

Traceback (most recent call last):
  File "run_qa.py", line 540, in <module>
    main()
  File "run_qa.py", line 492, in main
    metrics = compute_metrics(predictions)
  File "run_qa.py", line 460, in compute_metrics
    return metric.compute(predictions=p.predictions, references=p.label_ids)
  File "/usr/local/lib/python3.8/dist-packages/evaluate/module.py", line 444, in compute
    output = self._compute(**inputs, **compute_kwargs)
  File "/root/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--squad/b4e2dbca455821c7367faa26712f378254b69040ebaab90b64bdeb465e4a304d/squad.py", line 110, in _compute
    score = compute_score(dataset=dataset, predictions=pred_dict)
  File "/root/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--squad/b4e2dbca455821c7367faa26712f378254b69040ebaab90b64bdeb465e4a304d/compute_score.py", line 67, in compute_score
    exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
  File "/root/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--squad/b4e2dbca455821c7367faa26712f378254b69040ebaab90b64bdeb465e4a304d/compute_score.py", line 52, in metric_max_over_ground_truths
    return max(scores_for_ground_truths)
ValueError: max() arg is an empty sequence

Seems like a bug with the metric, an update might need to be done with the examples. I can take it over if you don't have bandwidth @echarlaix.

@JingyaHuang JingyaHuang merged commit 0808c8c into main Nov 28, 2022
@JingyaHuang JingyaHuang deleted the add_onnx_export_options branch November 28, 2022 21:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] ONNX optimization fails when optimizing AlbertXXL despite the weights being under 2GB

4 participants