Add ORTModel support for custom tasks #303

Merged
JingyaHuang merged 69 commits into main from jingya-refactoring-ort-model on Aug 3, 2022

@JingyaHuang JingyaHuang commented Jul 18, 2022

Contributor

What does this PR do?

ORTModelForXXX.model determines the valid inputs and outputs of the model's forward method, so the construction of both the inputs and the outputs in the forward method can be made abstract. This would make the ORTModels more flexible.

For example, in ORTTrainer, the evaluation includes labels as input and loss as output. This PR will make it easier to replace bare inference sessions with ORTModels.
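
A minimal sketch of the idea described above, not the code in this PR: the forward call reads the input and output names that the session declares, instead of hardcoding them per task. FakeSession is a hypothetical stand-in that only mimics the get_inputs(), get_outputs(), and run() methods of onnxruntime.InferenceSession, and generic_forward is an illustrative name.

```python
from types import SimpleNamespace


class FakeSession:
    """Hypothetical stand-in for onnxruntime.InferenceSession (illustration only)."""

    def __init__(self, input_names, output_names):
        self._inputs = [SimpleNamespace(name=n) for n in input_names]
        self._outputs = [SimpleNamespace(name=n) for n in output_names]

    def get_inputs(self):
        return self._inputs

    def get_outputs(self):
        return self._outputs

    def run(self, output_names, input_feed):
        # Dummy computation: return the number of fed inputs for each output.
        return [len(input_feed) for _ in output_names]


def generic_forward(session, **kwargs):
    """Build the input feed from whatever inputs the session declares."""
    input_names = [i.name for i in session.get_inputs()]
    missing = [n for n in input_names if n not in kwargs]
    if missing:
        raise ValueError(f"missing inputs: {missing}")
    # Keep only the arguments the graph actually accepts.
    feed = {n: kwargs[n] for n in input_names}
    output_names = [o.name for o in session.get_outputs()]
    results = session.run(output_names, feed)
    # Return outputs keyed by name, so "loss" is available if the graph has it.
    return dict(zip(output_names, results))
```

With a session that declares labels as an input and loss as an output, the same function serves evaluation without a task-specific class; the cost is the extra name lookups on every call.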

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

HuggingFaceDocBuilderDev commented Jul 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@regisss regisss left a comment

Contributor

Thanks for this @JingyaHuang!!
Just one nit

Comment thread optimum/onnxruntime/modeling_ort.py Outdated

@philschmid philschmid left a comment

Contributor

I am not sure we should make this change to enable something the classes aren't designed for.
You said:

For example, in ORTTrainer, the evaluation includes labels as input and loss as output. This PR will make it easier to replace bare inference sessions with ORTModels.

This sounds more like we should use evaluate and pipeline, or use the ORTModel in the trainer with the post-processing done outside of the model.

Comment thread optimum/onnxruntime/modeling_ort.py Outdated

philschmid commented Jul 18, 2022

Contributor

I still do not understand the purpose of this change.

As mentioned before, the ORTModelForXX classes were introduced for inference, not for training. The idea was to have API-compatible model classes that can be used with pipelines without the need to rewrite the pre- and post-processing.
Also, the idea is to add those inference model classes to other packages, e.g. optimum-intel.
Additionally, we are looking into removing the copying and adding IOBindings to reduce latency in the future.

The changes you suggest:

  • slow down the inference
  • add a lot of complex dynamic code, which we tried to avoid; that's why we have several ORTModelForXX classes rather than one
  • add support for something specific to ORTTrainer training

The questions I have are:

  • Is this change needed for the ORTTrainer? How are we currently doing it?
  • What is the benefit for the customer?

@philschmid philschmid left a comment

Contributor

Can you add a test? Then we should be good. Good idea! ✅ And if some use case emerges out of it, we can add new task-specific model classes.

Comment thread optimum/onnxruntime/modeling_ort.py
@JingyaHuang

Contributor Author

Hi @philschmid, sorry for the late reply. I re-drafted the code; indeed, it shouldn't be in the other task-specific models, as it would slow them down. The basic idea behind the PR is to leave some flexibility to users. It acts as a fallback, so that when they are using a more customized model they can still benefit from the ORTModel foundation with a small sacrifice of speed.
ORTTrainer.train() is independent of the PR. For the inference side of ORTTrainer, i.e. evaluate() and predict(), I am currently using InferenceSession directly and I will replace it with ORTModel. For predict() this is pretty straightforward, but for evaluate() the model includes the loss, so I need something more customized, and something like ORTModelForCustomTasks will be helpful.
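
To illustrate the evaluate() case, here is a rough sketch under stated assumptions, not the ORTTrainer implementation: the model is assumed to return its outputs as a dict keyed by output name, and fake_custom_model and evaluation_step are hypothetical names used only for this example.

```python
def evaluation_step(model, batch):
    """Run one evaluation batch and split the outputs into (loss, logits)."""
    outputs = model(**batch)  # assumed: a dict keyed by output name
    loss = outputs.get("loss")  # None when the graph exposes no loss output
    logits = outputs["logits"]
    return loss, logits


def fake_custom_model(**inputs):
    # Hypothetical stand-in for a custom-task model whose graph emits a loss
    # whenever labels are part of the inputs.
    outputs = {"logits": [0.1, 0.9]}
    if "labels" in inputs:
        outputs["loss"] = 0.25
    return outputs
```

Calling evaluation_step(fake_custom_model, {"input_ids": [1], "labels": [1]}) yields both a loss and logits, while the same call without labels yields a loss of None, which is the predict() situation.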

@JingyaHuang JingyaHuang changed the title from "Refactoring ort model inputs and outputs" to "Add ORTModel support for custom tasks" Jul 22, 2022
@JingyaHuang JingyaHuang changed the base branch from main to doc-builder-habana-test August 1, 2022 20:05
@JingyaHuang JingyaHuang changed the base branch from doc-builder-habana-test to main August 1, 2022 20:06
@JingyaHuang JingyaHuang merged commit d3c0b75 into main Aug 3, 2022
@JingyaHuang JingyaHuang deleted the jingya-refactoring-ort-model branch August 3, 2022 09:22