217 changes: 216 additions & 1 deletion optimum/intel/openvino/modeling.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union

import numpy as np
import openvino
import torch
import transformers
from huggingface_hub import model_info
from transformers import (
AutoConfig,
AutoModel,
@@ -31,6 +34,7 @@
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
PretrainedConfig,
)
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import (
@@ -45,8 +49,12 @@
)

from optimum.exporters import TasksManager
from optimum.exporters.onnx import export
from optimum.modeling_base import OptimizedModel

from .modeling_base import OVBaseModel
from .modeling_timm import TimmConfig, TimmForImageClassification, TimmOnnxConfig
from .utils import ONNX_WEIGHTS_NAME


logger = logging.getLogger(__name__)
@@ -497,6 +505,52 @@ class OVModelForImageClassification(OVModel):
def __init__(self, model=None, config=None, **kwargs):
super().__init__(model, config, **kwargs)

@classmethod
def from_pretrained(
cls,
model_id: Union[str, Path],
export: bool = False,
config: Optional["PretrainedConfig"] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
task: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
):
# timm checkpoints use a config format that differs from transformers, so dispatch them to OVModelForTimm
if not os.path.isdir(model_id) and model_info(model_id).library_name == "timm":
return OVModelForTimm.from_timm(
model_id=model_id,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
subfolder=subfolder,
local_files_only=local_files_only,
task=task,
trust_remote_code=trust_remote_code,
**kwargs,
)
else:
return super().from_pretrained(
model_id=model_id,
config=config,
export=export,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
subfolder=subfolder,
local_files_only=local_files_only,
task=task,
trust_remote_code=trust_remote_code,
**kwargs,
)
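
A minimal usage sketch of this dispatch (the checkpoint ids below are illustrative assumptions, not part of the change): a Hub repo whose `library_name` is `timm` is routed through `OVModelForTimm`, while any other id keeps the existing `OVModel` loading path.

```python
from optimum.intel import OVModelForImageClassification

# Illustrative timm checkpoint id: model_info(...).library_name == "timm",
# so loading is delegated to OVModelForTimm under the hood.
timm_model = OVModelForImageClassification.from_pretrained("timm/resnet18.a1_in1k")

# Illustrative transformers checkpoint id: falls through to the regular
# OVModel.from_pretrained path (here exporting from PyTorch to OpenVINO).
vit_model = OVModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", export=True
)
```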

@add_start_docstrings_to_model_forward(
IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ IMAGE_CLASSIFICATION_EXAMPLE.format(
@@ -526,6 +580,167 @@ def forward(
return ImageClassifierOutput(logits=logits)


class OVModelForTimm(OVModel):
@classmethod
def _from_transformers(
cls,
model_id: str,
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
task: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
):
task = task or cls.export_feature

model = TimmForImageClassification.from_pretrained(model_id, **kwargs)
onnx_config = TimmOnnxConfig(model.config)

with TemporaryDirectory() as save_dir:
save_dir_path = Path(save_dir)
export(
model=model,
config=onnx_config,
opset=onnx_config.DEFAULT_TIMM_ONNX_OPSET,
output=save_dir_path / ONNX_WEIGHTS_NAME,
)

return cls._from_pretrained(
model_id=save_dir_path,
config=config,
from_onnx=True,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
**kwargs,
)

@classmethod
def from_timm(
cls,
model_id: Union[str, Path],
export: bool = False,
force_download: bool = False,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
subfolder: str = "",
config: Optional["PretrainedConfig"] = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> "OptimizedModel":
"""
Returns:
`OptimizedModel`: The loaded optimized model.
"""
if isinstance(model_id, Path):
model_id = model_id.as_posix()

from_transformers = kwargs.pop("from_transformers", None)
if from_transformers is not None:
logger.warning(
"The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead"
)
export = from_transformers

if len(model_id.split("@")) == 2:
if revision is not None:
logger.warning(
f"The argument `revision` was set to {revision} but will be ignored for {model_id.split('@')[1]}"
)
model_id, revision = model_id.split("@")

# if config is None:
# if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
# if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
# config = AutoConfig.from_pretrained(
# os.path.join(model_id, subfolder, CONFIG_NAME), trust_remote_code=trust_remote_code
# )
# elif CONFIG_NAME in os.listdir(model_id):
# config = AutoConfig.from_pretrained(
# os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
# )
# logger.info(
# f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json."
# )
# else:
# raise OSError(f"config.json not found in {model_id} local folder")
# else:
config = cls._load_config(
model_id,
revision=revision,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
force_download=force_download,
subfolder=subfolder,
trust_remote_code=trust_remote_code,
)
# elif isinstance(config, (str, os.PathLike)):
# config = cls._load_config(
# config,
# revision=revision,
# cache_dir=cache_dir,
# use_auth_token=use_auth_token,
# force_download=force_download,
# subfolder=subfolder,
# trust_remote_code=trust_remote_code,
# )

if not export and trust_remote_code:
logger.warning(
"The argument `trust_remote_code` is to be used along with export=True. It will be ignored."
)
elif export and trust_remote_code is None:
trust_remote_code = False

# from_pretrained_method = cls._from_transformers if export else cls._from_pretrained
return cls._from_transformers(
model_id=model_id,
config=config,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
use_auth_token=use_auth_token,
subfolder=subfolder,
local_files_only=local_files_only,
trust_remote_code=trust_remote_code,
**kwargs,
)
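
A short sketch of the `model_id@revision` shorthand handled above (the checkpoint id and revision are illustrative assumptions):

```python
from optimum.intel.openvino.modeling import OVModelForTimm

# Illustrative: everything after "@" is split off and used as the revision,
# overriding any explicitly passed `revision` argument.
model = OVModelForTimm.from_timm("timm/resnet18.a1_in1k@main")
```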

def forward(
self,
pixel_values: Union[torch.Tensor, np.ndarray],
**kwargs,
):
self.compile()

np_inputs = isinstance(pixel_values, np.ndarray)
if not np_inputs:
pixel_values = np.array(pixel_values)

inputs = {
"pixel_values": pixel_values,
}

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return ImageClassifierOutput(logits=logits)
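
A minimal inference sketch against the `forward` above, reusing a `model` loaded as in the earlier examples (the input shape is an illustrative assumption):

```python
import numpy as np

# Illustrative input: one 224x224 RGB image in NCHW layout.
pixel_values = np.random.rand(1, 3, 224, 224).astype(np.float32)

# With a numpy input the logits stay numpy; a torch.Tensor input would be
# converted to numpy for inference and the logits returned as a torch.Tensor.
outputs = model(pixel_values=pixel_values)
print(outputs.logits.shape)
```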

@classmethod
def _load_config(cls, model_id, **kwargs):
return TimmConfig.from_pretrained(model_id, **kwargs)


AUDIO_CLASSIFICATION_EXAMPLE = r"""
Example of audio classification using `transformers.pipelines`:
```python
5 changes: 1 addition & 4 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -217,9 +217,6 @@ def _from_pretrained(
cls.config_name,
}
)
ignore_patterns = ["*.msgpack", "*.safetensors", "*pytorch_model.bin"]
if not from_onnx:
ignore_patterns.extend(["*.onnx", "*.onnx_data"])
# Downloads all repo's files matching the allowed patterns
model_id = snapshot_download(
model_id,
@@ -228,7 +225,7 @@
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
ignore_patterns=["*.msgpack", "*.safetensors", "*pytorch_model.bin"],
)
new_model_save_dir = Path(model_id)
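
For reference, a hedged sketch of how the allow/ignore patterns interact in `snapshot_download` (the repo id and pattern lists are illustrative assumptions): with this change, `*.onnx` files are no longer excluded when loading a non-ONNX pipeline, and only the msgpack/safetensors/PyTorch weights are skipped.

```python
from huggingface_hub import snapshot_download

# Illustrative repo id and patterns: only files matching allow_patterns and
# not matching ignore_patterns are downloaded.
local_dir = snapshot_download(
    "some-org/some-stable-diffusion-repo",  # hypothetical repo id
    allow_patterns=["*.xml", "*.bin", "*.json", "*.txt", "*.onnx"],
    ignore_patterns=["*.msgpack", "*.safetensors", "*pytorch_model.bin"],
)
```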
