2121from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Union
2222
2323import onnx
24- from datasets import Dataset , load_dataset
2524from packaging .version import Version , parse
2625from transformers import AutoConfig
2726
2827from onnxruntime import __version__ as ort_version
2928from onnxruntime .quantization import CalibrationDataReader , QuantFormat , QuantizationMode , QuantType
3029from onnxruntime .quantization .onnx_quantizer import ONNXQuantizer
3130from onnxruntime .quantization .qdq_quantizer import QDQQuantizer
31+ from optimum .utils .import_utils import requires_backends
3232
3333from ..quantization_base import OptimumQuantizer
3434from ..utils .save_utils import maybe_save_preprocessors
4040
4141
4242if TYPE_CHECKING :
43+ from datasets import Dataset
4344 from transformers import PretrainedConfig
4445
4546LOGGER = logging .getLogger (__name__ )
4849class ORTCalibrationDataReader (CalibrationDataReader ):
4950 __slots__ = ["batch_size" , "dataset" , "_dataset_iter" ]
5051
51- def __init__ (self , dataset : Dataset , batch_size : int = 1 ):
52+ def __init__ (self , dataset : " Dataset" , batch_size : int = 1 ):
5253 if dataset is None :
5354 raise ValueError ("Provided dataset is None." )
5455
@@ -158,7 +159,7 @@ def from_pretrained(
158159
159160 def fit (
160161 self ,
161- dataset : Dataset ,
162+ dataset : " Dataset" ,
162163 calibration_config : CalibrationConfig ,
163164 onnx_augmented_model_name : Union [str , Path ] = "augmented_model.onnx" ,
164165 operators_to_quantize : Optional [List [str ]] = None ,
@@ -212,7 +213,7 @@ def fit(
212213
213214 def partial_fit (
214215 self ,
215- dataset : Dataset ,
216+ dataset : " Dataset" ,
216217 calibration_config : CalibrationConfig ,
217218 onnx_augmented_model_name : Union [str , Path ] = "augmented_model.onnx" ,
218219 operators_to_quantize : Optional [List [str ]] = None ,
@@ -428,7 +429,7 @@ def get_calibration_dataset(
428429 seed : int = 2016 ,
429430 use_auth_token : Optional [Union [bool , str ]] = None ,
430431 token : Optional [Union [bool , str ]] = None ,
431- ) -> Dataset :
432+ ) -> " Dataset" :
432433 """
433434 Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
434435
@@ -474,6 +475,10 @@ def get_calibration_dataset(
474475 "provided."
475476 )
476477
478+ requires_backends (self , ["datasets" ])
479+
480+ from datasets import load_dataset
481+
477482 calib_dataset = load_dataset (
478483 dataset_name ,
479484 name = dataset_config_name ,
@@ -492,7 +497,7 @@ def get_calibration_dataset(
492497
493498 return self .clean_calibration_dataset (processed_calib_dataset )
494499
495- def clean_calibration_dataset (self , dataset : Dataset ) -> Dataset :
500+ def clean_calibration_dataset (self , dataset : " Dataset" ) -> " Dataset" :
496501 model = onnx .load (self .onnx_model_path )
497502 model_inputs = {input .name for input in model .graph .input }
498503 ignored_columns = list (set (dataset .column_names ) - model_inputs )
0 commit comments