@@ -56,6 +56,7 @@ def convert_pt2(
5656 model : torch .nn .Module ,
5757 inputs : tuple [object , ...],
5858 quantizer : CadenceQuantizer ,
59+ calibration_data : Optional [list [tuple [object , ...]]] = None ,
5960 dump_graphs : bool = False ,
6061) -> torch .fx .GraphModule :
6162 """
@@ -64,6 +65,9 @@ def convert_pt2(
6465 fuse the model later, if applicable. If you do not expect that behavior,
6566 please use quantize_and_fuse_pt2 instead, which will instantiate a
6667 default quantizer for you if needed.
68+ If calibration data is provided, it will be used to calibrate the model. If
69+ not, the inputs will be used for calibration instead, which is useful for
70+ unit tests but should not be used for end-to-end use cases.
6771 Returns a GraphModule with the converted model.
6872 """
6973
@@ -95,7 +99,12 @@ def convert_pt2(
9599 prepared_model = prepare_pt2e (model_gm , quantizer )
96100
97101 # Calibrate
98- prepared_model (* inputs )
102+ # If no calibration data is provided, use the inputs
103+ if calibration_data is None :
104+ calibration_data = [inputs ]
105+
106+ for samples in calibration_data :
107+ prepared_model (* samples )
99108
100109 # Convert
101110 converted_model = convert_pt2e (prepared_model )
@@ -136,10 +145,14 @@ def quantize_pt2(
136145 model : torch .nn .Module ,
137146 inputs : tuple [object , ...],
138147 quantizer : Optional [CadenceQuantizer ] = None ,
148+ calibration_data : Optional [list [tuple [object , ...]]] = None ,
139149 dump_graphs : bool = False ,
140150) -> torch .fx .GraphModule :
141151 """
142152 Prepare, convert and fuse the model using the given quantizer.
153+ If calibration data is provided, it will be used to calibrate the model. If
154+ not, the inputs will be used for calibration instead, which is useful for
155+ unit tests but should not be used for end-to-end use cases.
143156 Returns a GraphModule with the quantized model.
144157 """
145158 # Make the model inference mode by calling model.eval()
@@ -150,7 +163,9 @@ def quantize_pt2(
150163 quantizer = CadenceDefaultQuantizer ()
151164
152165 # Get converted graph module
153- converted_gm = convert_pt2 (model , inputs , quantizer , dump_graphs )
166+ converted_gm = convert_pt2 (
167+ model , inputs , quantizer , calibration_data , dump_graphs = dump_graphs
168+ )
154169
155170 # Get fused model
156171 fused_gm = fuse_pt2 (converted_gm , quantizer )
0 commit comments