Skip to content

Commit 350c8f9

Browse files
committed
fix loading seq2seq models
1 parent 04b2898 commit 350c8f9

1 file changed

Lines changed: 47 additions & 101 deletions

File tree

optimum/onnxruntime/modeling_seq2seq.py

Lines changed: 47 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ORTModelForXXX classes related to seq2seq, allowing ONNX models to be run with ONNX Runtime using the same API as Transformers.
1616
"""
1717

18+
import re
1819
from pathlib import Path
1920
from tempfile import TemporaryDirectory
2021
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union
@@ -42,7 +43,7 @@
4243
from ..exporters.onnx import main_export
4344
from ..exporters.tasks import TasksManager
4445
from ..utils import NormalizedConfigManager, is_transformers_version
45-
from ..utils.file_utils import validate_file_exists
46+
from ..utils.file_utils import find_files_matching_pattern
4647
from ..utils.logging import get_logger, warn_once
4748
from ..utils.save_utils import maybe_save_preprocessors
4849
from .base import ORTParentMixin, ORTSessionMixin
@@ -51,6 +52,7 @@
5152
DECODER_ONNX_FILE_PATTERN,
5253
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
5354
ENCODER_ONNX_FILE_PATTERN,
55+
ONNX_FILE_PATTERN,
5456
)
5557
from .modeling_ort import ORTModel
5658
from .utils import (
@@ -1067,8 +1069,6 @@ def _from_pretrained(
10671069
# other arguments
10681070
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
10691071
):
1070-
model_path = Path(model_id)
1071-
10721072
# We do not implement the logic for use_cache=False, use_merged=True
10731073
if use_cache is False:
10741074
if use_merged is True:
@@ -1078,125 +1078,70 @@ def _from_pretrained(
10781078
)
10791079
use_merged = False
10801080

1081-
decoder_merged_path = None
1082-
# We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
1083-
# and use_merged = True (explicitly specified by the user)
1084-
if use_merged is not False:
1085-
try:
1086-
decoder_merged_path = ORTModel._infer_onnx_filename(
1087-
model_id,
1088-
[DECODER_MERGED_ONNX_FILE_PATTERN],
1089-
argument_name=None,
1090-
subfolder=subfolder,
1091-
token=token,
1092-
revision=revision,
1093-
)
1094-
use_merged = True
1095-
decoder_path = decoder_merged_path
1096-
except FileNotFoundError as e:
1097-
if use_merged is True:
1098-
raise FileNotFoundError(
1099-
"The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
1100-
" but no ONNX file for a merged decoder could be found in"
1101-
f" {str(Path(model_id, subfolder))}, with the error: {e}"
1102-
)
1103-
use_merged = False
1081+
model_path = Path(model_id)
1082+
1083+
onnx_files = find_files_matching_pattern(
1084+
model_id,
1085+
ONNX_FILE_PATTERN,
1086+
glob_pattern="**/*.onnx",
1087+
subfolder=subfolder,
1088+
token=token,
1089+
revision=revision,
1090+
)
11041091

1092+
if len(onnx_files) == 0:
1093+
raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")
1094+
1095+
decoder_merged_path = None
11051096
decoder_without_past_path = None
11061097
decoder_with_past_path = None
1107-
if use_merged is False:
1108-
if not validate_file_exists(
1109-
model_id, decoder_file_name, subfolder=subfolder, revision=revision, token=token
1110-
):
1111-
decoder_without_past_path = ORTModel._infer_onnx_filename(
1112-
model_id,
1113-
[DECODER_ONNX_FILE_PATTERN],
1114-
"decoder_file_name",
1115-
subfolder=subfolder,
1116-
token=token,
1117-
revision=revision,
1118-
)
1119-
else:
1120-
decoder_without_past_path = model_path / subfolder / decoder_file_name
11211098

1122-
decoder_path = decoder_without_past_path
1099+
model_files = []
1100+
# Check first for merged models and then for decoder / decoder_with_past models
1101+
if use_merged is not False:
1102+
model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
1103+
use_merged = len(model_files) != 0
11231104

1124-
decoder_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_DECODER_NAME)
1125-
if decoder_path.name not in decoder_regular_onnx_filenames:
1126-
logger.warning(
1127-
f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
1128-
f"{cls.__name__} might not behave as expected."
1105+
if use_merged is False:
1106+
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
1107+
model_files = [p for p in onnx_files if re.search(pattern, str(p))]
1108+
if use_cache:
1109+
decoder_with_past_path = [file for file in model_files if file.name == decoder_with_past_file_name]
1110+
decoder_with_past_path = decoder_with_past_path[0] if decoder_with_past_path else model_files[0]
1111+
decoder_without_past_path = decoder_without_past_path.parent / decoder_without_past_path.name.replace(
1112+
"_with_past", ""
11291113
)
1130-
1131-
# If the decoder without / with past has been merged, we do not need to look for any additional file
1132-
if use_cache is True and use_merged is False:
1133-
if not validate_file_exists(
1134-
model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision, token=token
1135-
):
1136-
try:
1137-
decoder_with_past_path = ORTModel._infer_onnx_filename(
1138-
model_id,
1139-
[DECODER_WITH_PAST_ONNX_FILE_PATTERN],
1140-
"decoder_with_past_file_name",
1141-
subfolder=subfolder,
1142-
token=token,
1143-
revision=revision,
1144-
)
1145-
except FileNotFoundError as e:
1146-
raise FileNotFoundError(
1147-
"The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()"
1148-
" but no ONNX file using past key values could be found in"
1149-
f" {str(Path(model_id, subfolder))}, with the error: {e}"
1150-
)
1151-
else:
1152-
decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name
1153-
1154-
decoder_path = decoder_without_past_path
1155-
1156-
decoder_with_past_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(
1157-
ONNX_DECODER_WITH_PAST_NAME
1114+
else:
1115+
decoder_without_past_path = [file for file in model_files if file.name == decoder_file_name]
1116+
decoder_without_past_path = (
1117+
decoder_without_past_path[0] if decoder_without_past_path else model_files[0]
11581118
)
1159-
1160-
if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames:
1161-
logger.warning(
1162-
f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
1163-
f"the {cls.__name__} might not behave as expected."
1164-
)
1165-
1166-
if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision, token=token):
1167-
encoder_path = ORTModel._infer_onnx_filename(
1168-
model_id,
1169-
[ENCODER_ONNX_FILE_PATTERN],
1170-
"encoder_file_name",
1171-
subfolder=subfolder,
1172-
token=token,
1173-
revision=revision,
1174-
)
11751119
else:
1176-
encoder_path = model_path / subfolder / encoder_file_name
1120+
decoder_merged_path = model_files[0]
11771121

1178-
encoder_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_ENCODER_NAME)
1179-
if encoder_path.name not in encoder_regular_onnx_filenames:
1180-
logger.warning(
1181-
f"The ONNX file {encoder_path.name} is not a regular name used in optimum.onnxruntime, the "
1182-
"ORTModelForConditionalGeneration might not behave as expected."
1183-
)
1122+
model_files = [p for p in onnx_files if re.search(ENCODER_ONNX_FILE_PATTERN, str(p))]
1123+
encoder_path = [file for file in model_files if file.name == encoder_file_name]
1124+
encoder_path = encoder_path[0] if encoder_path else model_files[0]
11841125

11851126
if model_path.is_dir():
11861127
new_model_save_dir = model_path
11871128
else:
11881129
attribute_name_to_filename = {
1189-
"last_encoder_model_name": encoder_path.name,
1190-
"last_decoder_model_name": decoder_path.name if use_merged is False else None,
1130+
"last_encoder_model_name": encoder_path,
1131+
"last_decoder_model_name": decoder_path if use_merged is False else None,
11911132
"last_decoder_with_past_model_name": (
1192-
decoder_with_past_path.name if (use_merged is False and use_cache is True) else None
1133+
decoder_with_past_path if (use_merged is False and use_cache is True) else None
11931134
),
1194-
"last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None,
1135+
"last_decoder_merged_name": decoder_merged_path if use_merged is True else None,
11951136
}
11961137
paths = {}
11971138
for attr_name, filename in attribute_name_to_filename.items():
11981139
if filename is None:
11991140
continue
1141+
1142+
subfolder = filename.parent.as_posix()
1143+
filename = filename.name
1144+
12001145
model_cache_path = cached_file(
12011146
model_id,
12021147
filename=filename,
@@ -1223,6 +1168,7 @@ def _from_pretrained(
12231168
pass
12241169

12251170
paths[attr_name] = Path(model_cache_path).name
1171+
12261172
new_model_save_dir = Path(model_cache_path).parent
12271173

12281174
if use_merged is True:

0 commit comments

Comments
 (0)