Skip to content

Commit da5200b

Browse files
authored
Add compatibility with transformers 4.52 (#2270)
* compatibility with transformers 4.52 * fix * add back default * update setup * fix * fix perceiver * fix perceiver test * fix style * add min max tokens pipeline test * fix * comment * upgrade runner * increase batch size for test * fix test * style * update model * fix * add model * add fix * to rm * revert * style * run test in serie * fix tests * fix * tmp * update model * trigger test * revert * fix * fix default model id for pipelines * style * fix * fix * style * fix loading seq2seq models * add test * style * apply comments * remove unused * rename * style * add more tests * add test * fix * style
1 parent e15053d commit da5200b

7 files changed

Lines changed: 219 additions & 183 deletions

File tree

.github/workflows/test_common.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
python-version: [3.9]
22-
runs-on: [ubuntu-22.04, windows-2019, macos-13]
22+
runs-on: [ubuntu-22.04, windows-2019, macos-14]
2323

2424
runs-on: ${{ matrix.runs-on }}
2525

optimum/onnxruntime/modeling_seq2seq.py

Lines changed: 44 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ORTModelForXXX classes related to seq2seq, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers.
1616
"""
1717

18+
import re
1819
from pathlib import Path
1920
from tempfile import TemporaryDirectory
2021
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union
@@ -42,7 +43,7 @@
4243
from ..exporters.onnx import main_export
4344
from ..exporters.tasks import TasksManager
4445
from ..utils import NormalizedConfigManager, is_transformers_version
45-
from ..utils.file_utils import validate_file_exists
46+
from ..utils.file_utils import find_files_matching_pattern
4647
from ..utils.logging import get_logger, warn_once
4748
from ..utils.save_utils import maybe_save_preprocessors
4849
from .base import ORTParentMixin, ORTSessionMixin
@@ -51,6 +52,7 @@
5152
DECODER_ONNX_FILE_PATTERN,
5253
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
5354
ENCODER_ONNX_FILE_PATTERN,
55+
ONNX_FILE_PATTERN,
5456
)
5557
from .modeling_ort import ORTModel
5658
from .utils import (
@@ -1067,8 +1069,6 @@ def _from_pretrained(
10671069
# other arguments
10681070
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
10691071
):
1070-
model_path = Path(model_id)
1071-
10721072
# We do not implement the logic for use_cache=False, use_merged=True
10731073
if use_cache is False:
10741074
if use_merged is True:
@@ -1078,125 +1078,61 @@ def _from_pretrained(
10781078
)
10791079
use_merged = False
10801080

1081-
decoder_merged_path = None
1082-
# We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
1083-
# and use_merged = True (explicitely specified by the user)
1084-
if use_merged is not False:
1085-
try:
1086-
decoder_merged_path = ORTModel._infer_onnx_filename(
1087-
model_id,
1088-
[DECODER_MERGED_ONNX_FILE_PATTERN],
1089-
argument_name=None,
1090-
subfolder=subfolder,
1091-
token=token,
1092-
revision=revision,
1093-
)
1094-
use_merged = True
1095-
decoder_path = decoder_merged_path
1096-
except FileNotFoundError as e:
1097-
if use_merged is True:
1098-
raise FileNotFoundError(
1099-
"The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
1100-
" but no ONNX file for a merged decoder could be found in"
1101-
f" {str(Path(model_id, subfolder))}, with the error: {e}"
1102-
)
1103-
use_merged = False
1104-
1105-
decoder_without_past_path = None
1106-
decoder_with_past_path = None
1107-
if use_merged is False:
1108-
if not validate_file_exists(
1109-
model_id, decoder_file_name, subfolder=subfolder, revision=revision, token=token
1110-
):
1111-
decoder_without_past_path = ORTModel._infer_onnx_filename(
1112-
model_id,
1113-
[DECODER_ONNX_FILE_PATTERN],
1114-
"decoder_file_name",
1115-
subfolder=subfolder,
1116-
token=token,
1117-
revision=revision,
1118-
)
1119-
else:
1120-
decoder_without_past_path = model_path / subfolder / decoder_file_name
1121-
1122-
decoder_path = decoder_without_past_path
1123-
1124-
decoder_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_DECODER_NAME)
1125-
if decoder_path.name not in decoder_regular_onnx_filenames:
1126-
logger.warning(
1127-
f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
1128-
f"{cls.__name__} might not behave as expected."
1129-
)
1081+
model_path = Path(model_id)
11301082

1131-
# If the decoder without / with past has been merged, we do not need to look for any additional file
1132-
if use_cache is True and use_merged is False:
1133-
if not validate_file_exists(
1134-
model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision, token=token
1135-
):
1136-
try:
1137-
decoder_with_past_path = ORTModel._infer_onnx_filename(
1138-
model_id,
1139-
[DECODER_WITH_PAST_ONNX_FILE_PATTERN],
1140-
"decoder_with_past_file_name",
1141-
subfolder=subfolder,
1142-
token=token,
1143-
revision=revision,
1144-
)
1145-
except FileNotFoundError as e:
1146-
raise FileNotFoundError(
1147-
"The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()"
1148-
" but no ONNX file using past key values could be found in"
1149-
f" {str(Path(model_id, subfolder))}, with the error: {e}"
1150-
)
1151-
else:
1152-
decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name
1083+
onnx_files = find_files_matching_pattern(
1084+
model_id,
1085+
ONNX_FILE_PATTERN,
1086+
glob_pattern="**/*.onnx",
1087+
subfolder=subfolder,
1088+
token=token,
1089+
revision=revision,
1090+
)
11531091

1154-
decoder_path = decoder_without_past_path
1092+
if len(onnx_files) == 0:
1093+
raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")
11551094

1156-
decoder_with_past_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(
1157-
ONNX_DECODER_WITH_PAST_NAME
1158-
)
1095+
decoder_path = None
1096+
decoder_with_past_path = None
1097+
# Check first for merged models and then for decoder / decoder_with_past models
1098+
if use_merged is not False:
1099+
model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
1100+
use_merged = len(model_files) != 0
11591101

1160-
if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames:
1161-
logger.warning(
1162-
f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
1163-
f"the {cls.__name__} might not behave as expected."
1164-
)
1102+
if use_merged is False:
1103+
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
1104+
model_files = [p for p in onnx_files if re.search(pattern, str(p))]
11651105

1166-
if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision, token=token):
1167-
encoder_path = ORTModel._infer_onnx_filename(
1168-
model_id,
1169-
[ENCODER_ONNX_FILE_PATTERN],
1170-
"encoder_file_name",
1171-
subfolder=subfolder,
1172-
token=token,
1173-
revision=revision,
1174-
)
1106+
if use_cache:
1107+
decoder_with_past_path = [file for file in model_files if file.name == decoder_with_past_file_name]
1108+
decoder_with_past_path = decoder_with_past_path[0] if decoder_with_past_path else model_files[0]
1109+
decoder_path = decoder_with_past_path.parent / decoder_with_past_path.name.replace("_with_past", "")
1110+
else:
1111+
decoder_path = [file for file in model_files if file.name == decoder_file_name]
1112+
decoder_path = decoder_path[0] if decoder_path else model_files[0]
11751113
else:
1176-
encoder_path = model_path / subfolder / encoder_file_name
1114+
decoder_path = model_files[0]
11771115

1178-
encoder_regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_ENCODER_NAME)
1179-
if encoder_path.name not in encoder_regular_onnx_filenames:
1180-
logger.warning(
1181-
f"The ONNX file {encoder_path.name} is not a regular name used in optimum.onnxruntime, the "
1182-
"ORTModelForConditionalGeneration might not behave as expected."
1183-
)
1116+
model_files = [p for p in onnx_files if re.search(ENCODER_ONNX_FILE_PATTERN, str(p))]
1117+
encoder_path = [file for file in model_files if file.name == encoder_file_name]
1118+
encoder_path = encoder_path[0] if encoder_path else model_files[0]
11841119

11851120
if model_path.is_dir():
11861121
new_model_save_dir = model_path
11871122
else:
11881123
attribute_name_to_filename = {
1189-
"last_encoder_model_name": encoder_path.name,
1190-
"last_decoder_model_name": decoder_path.name if use_merged is False else None,
1191-
"last_decoder_with_past_model_name": (
1192-
decoder_with_past_path.name if (use_merged is False and use_cache is True) else None
1193-
),
1194-
"last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None,
1124+
"last_encoder_model_name": encoder_path,
1125+
"last_decoder_model_name": decoder_path if not use_merged else None,
1126+
"last_decoder_with_past_model_name": decoder_with_past_path if not use_merged and use_cache else None,
1127+
"last_decoder_merged_name": decoder_path if use_merged else None,
11951128
}
11961129
paths = {}
11971130
for attr_name, filename in attribute_name_to_filename.items():
11981131
if filename is None:
11991132
continue
1133+
1134+
subfolder = filename.parent.as_posix()
1135+
filename = filename.name
12001136
model_cache_path = cached_file(
12011137
model_id,
12021138
filename=filename,
@@ -1223,16 +1159,15 @@ def _from_pretrained(
12231159
pass
12241160

12251161
paths[attr_name] = Path(model_cache_path).name
1162+
12261163
new_model_save_dir = Path(model_cache_path).parent
12271164

1228-
if use_merged is True:
1165+
if use_merged:
12291166
decoder_path = new_model_save_dir / paths["last_decoder_merged_name"]
1230-
decoder_merged_path = new_model_save_dir / paths["last_decoder_merged_name"]
12311167
else:
12321168
decoder_path = new_model_save_dir / paths["last_decoder_model_name"]
1233-
decoder_without_past_path = new_model_save_dir / paths["last_decoder_model_name"]
12341169

1235-
if use_cache is True:
1170+
if use_cache:
12361171
decoder_with_past_path = new_model_save_dir / paths["last_decoder_with_past_model_name"]
12371172

12381173
encoder_path = new_model_save_dir / paths["last_encoder_model_name"]

0 commit comments

Comments
 (0)