1515ORTModelForXXX classes related to seq2seq, which allow running ONNX models with ONNX Runtime using the same API as Transformers.
1616"""
1717
18+ import re
1819from pathlib import Path
1920from tempfile import TemporaryDirectory
2021from typing import TYPE_CHECKING , Any , Dict , Optional , Sequence , Set , Tuple , Union
4243from ..exporters .onnx import main_export
4344from ..exporters .tasks import TasksManager
4445from ..utils import NormalizedConfigManager , is_transformers_version
45- from ..utils .file_utils import validate_file_exists
46+ from ..utils .file_utils import find_files_matching_pattern
4647from ..utils .logging import get_logger , warn_once
4748from ..utils .save_utils import maybe_save_preprocessors
4849from .base import ORTParentMixin , ORTSessionMixin
5152 DECODER_ONNX_FILE_PATTERN ,
5253 DECODER_WITH_PAST_ONNX_FILE_PATTERN ,
5354 ENCODER_ONNX_FILE_PATTERN ,
55+ ONNX_FILE_PATTERN ,
5456)
5557from .modeling_ort import ORTModel
5658from .utils import (
@@ -1067,8 +1069,6 @@ def _from_pretrained(
10671069 # other arguments
10681070 model_save_dir : Optional [Union [str , Path , TemporaryDirectory ]] = None ,
10691071 ):
1070- model_path = Path (model_id )
1071-
10721072 # We do not implement the logic for use_cache=False, use_merged=True
10731073 if use_cache is False :
10741074 if use_merged is True :
@@ -1078,125 +1078,61 @@ def _from_pretrained(
10781078 )
10791079 use_merged = False
10801080
1081- decoder_merged_path = None
1082- # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
1083- # and use_merged = True (explicitely specified by the user)
1084- if use_merged is not False :
1085- try :
1086- decoder_merged_path = ORTModel ._infer_onnx_filename (
1087- model_id ,
1088- [DECODER_MERGED_ONNX_FILE_PATTERN ],
1089- argument_name = None ,
1090- subfolder = subfolder ,
1091- token = token ,
1092- revision = revision ,
1093- )
1094- use_merged = True
1095- decoder_path = decoder_merged_path
1096- except FileNotFoundError as e :
1097- if use_merged is True :
1098- raise FileNotFoundError (
1099- "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
1100- " but no ONNX file for a merged decoder could be found in"
1101- f" { str (Path (model_id , subfolder ))} , with the error: { e } "
1102- )
1103- use_merged = False
1104-
1105- decoder_without_past_path = None
1106- decoder_with_past_path = None
1107- if use_merged is False :
1108- if not validate_file_exists (
1109- model_id , decoder_file_name , subfolder = subfolder , revision = revision , token = token
1110- ):
1111- decoder_without_past_path = ORTModel ._infer_onnx_filename (
1112- model_id ,
1113- [DECODER_ONNX_FILE_PATTERN ],
1114- "decoder_file_name" ,
1115- subfolder = subfolder ,
1116- token = token ,
1117- revision = revision ,
1118- )
1119- else :
1120- decoder_without_past_path = model_path / subfolder / decoder_file_name
1121-
1122- decoder_path = decoder_without_past_path
1123-
1124- decoder_regular_onnx_filenames = ORTModel ._generate_regular_names_for_filename (ONNX_DECODER_NAME )
1125- if decoder_path .name not in decoder_regular_onnx_filenames :
1126- logger .warning (
1127- f"The ONNX file { decoder_path .name } is not a regular name used in optimum.onnxruntime that are { decoder_regular_onnx_filenames } , the "
1128- f"{ cls .__name__ } might not behave as expected."
1129- )
1081+ model_path = Path (model_id )
11301082
1131- # If the decoder without / with past has been merged, we do not need to look for any additional file
1132- if use_cache is True and use_merged is False :
1133- if not validate_file_exists (
1134- model_id , decoder_with_past_file_name , subfolder = subfolder , revision = revision , token = token
1135- ):
1136- try :
1137- decoder_with_past_path = ORTModel ._infer_onnx_filename (
1138- model_id ,
1139- [DECODER_WITH_PAST_ONNX_FILE_PATTERN ],
1140- "decoder_with_past_file_name" ,
1141- subfolder = subfolder ,
1142- token = token ,
1143- revision = revision ,
1144- )
1145- except FileNotFoundError as e :
1146- raise FileNotFoundError (
1147- "The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()"
1148- " but no ONNX file using past key values could be found in"
1149- f" { str (Path (model_id , subfolder ))} , with the error: { e } "
1150- )
1151- else :
1152- decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name
1083+ onnx_files = find_files_matching_pattern (
1084+ model_id ,
1085+ ONNX_FILE_PATTERN ,
1086+ glob_pattern = "**/*.onnx" ,
1087+ subfolder = subfolder ,
1088+ token = token ,
1089+ revision = revision ,
1090+ )
11531091
1154- decoder_path = decoder_without_past_path
1092+ if len (onnx_files ) == 0 :
1093+ raise FileNotFoundError (f"Could not find any ONNX model file in { model_id } " )
11551094
1156- decoder_with_past_regular_onnx_filenames = ORTModel ._generate_regular_names_for_filename (
1157- ONNX_DECODER_WITH_PAST_NAME
1158- )
1095+ decoder_path = None
1096+ decoder_with_past_path = None
1097+ # Check first for merged models and then for decoder / decoder_with_past models
1098+ if use_merged is not False :
1099+ model_files = [p for p in onnx_files if re .search (DECODER_MERGED_ONNX_FILE_PATTERN , str (p ))]
1100+ use_merged = len (model_files ) != 0
11591101
1160- if decoder_with_past_path .name not in decoder_with_past_regular_onnx_filenames :
1161- logger .warning (
1162- f"The ONNX file { decoder_with_past_path .name } is not a regular name used in optimum.onnxruntime that are { decoder_with_past_regular_onnx_filenames } , "
1163- f"the { cls .__name__ } might not behave as expected."
1164- )
1102+ if use_merged is False :
1103+ pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
1104+ model_files = [p for p in onnx_files if re .search (pattern , str (p ))]
11651105
1166- if not validate_file_exists (model_id , encoder_file_name , subfolder = subfolder , revision = revision , token = token ):
1167- encoder_path = ORTModel ._infer_onnx_filename (
1168- model_id ,
1169- [ENCODER_ONNX_FILE_PATTERN ],
1170- "encoder_file_name" ,
1171- subfolder = subfolder ,
1172- token = token ,
1173- revision = revision ,
1174- )
1106+ if use_cache :
1107+ decoder_with_past_path = [file for file in model_files if file .name == decoder_with_past_file_name ]
1108+ decoder_with_past_path = decoder_with_past_path [0 ] if decoder_with_past_path else model_files [0 ]
1109+ decoder_path = decoder_with_past_path .parent / decoder_with_past_path .name .replace ("_with_past" , "" )
1110+ else :
1111+ decoder_path = [file for file in model_files if file .name == decoder_file_name ]
1112+ decoder_path = decoder_path [0 ] if decoder_path else model_files [0 ]
11751113 else :
1176- encoder_path = model_path / subfolder / encoder_file_name
1114+ decoder_path = model_files [ 0 ]
11771115
1178- encoder_regular_onnx_filenames = ORTModel ._generate_regular_names_for_filename (ONNX_ENCODER_NAME )
1179- if encoder_path .name not in encoder_regular_onnx_filenames :
1180- logger .warning (
1181- f"The ONNX file { encoder_path .name } is not a regular name used in optimum.onnxruntime, the "
1182- "ORTModelForConditionalGeneration might not behave as expected."
1183- )
1116+ model_files = [p for p in onnx_files if re .search (ENCODER_ONNX_FILE_PATTERN , str (p ))]
1117+ encoder_path = [file for file in model_files if file .name == encoder_file_name ]
1118+ encoder_path = encoder_path [0 ] if encoder_path else model_files [0 ]
11841119
11851120 if model_path .is_dir ():
11861121 new_model_save_dir = model_path
11871122 else :
11881123 attribute_name_to_filename = {
1189- "last_encoder_model_name" : encoder_path .name ,
1190- "last_decoder_model_name" : decoder_path .name if use_merged is False else None ,
1191- "last_decoder_with_past_model_name" : (
1192- decoder_with_past_path .name if (use_merged is False and use_cache is True ) else None
1193- ),
1194- "last_decoder_merged_name" : decoder_merged_path .name if use_merged is True else None ,
1124+ "last_encoder_model_name" : encoder_path ,
1125+ "last_decoder_model_name" : decoder_path if not use_merged else None ,
1126+ "last_decoder_with_past_model_name" : decoder_with_past_path if not use_merged and use_cache else None ,
1127+ "last_decoder_merged_name" : decoder_path if use_merged else None ,
11951128 }
11961129 paths = {}
11971130 for attr_name , filename in attribute_name_to_filename .items ():
11981131 if filename is None :
11991132 continue
1133+
1134+ subfolder = filename .parent .as_posix ()
1135+ filename = filename .name
12001136 model_cache_path = cached_file (
12011137 model_id ,
12021138 filename = filename ,
@@ -1223,16 +1159,15 @@ def _from_pretrained(
12231159 pass
12241160
12251161 paths [attr_name ] = Path (model_cache_path ).name
1162+
12261163 new_model_save_dir = Path (model_cache_path ).parent
12271164
1228- if use_merged is True :
1165+ if use_merged :
12291166 decoder_path = new_model_save_dir / paths ["last_decoder_merged_name" ]
1230- decoder_merged_path = new_model_save_dir / paths ["last_decoder_merged_name" ]
12311167 else :
12321168 decoder_path = new_model_save_dir / paths ["last_decoder_model_name" ]
1233- decoder_without_past_path = new_model_save_dir / paths ["last_decoder_model_name" ]
12341169
1235- if use_cache is True :
1170+ if use_cache :
12361171 decoder_with_past_path = new_model_save_dir / paths ["last_decoder_with_past_model_name" ]
12371172
12381173 encoder_path = new_model_save_dir / paths ["last_encoder_model_name" ]
0 commit comments