Skip to content

Commit 408262b

Browse files
gonnet authored and copybara-github committed
Add and use a hidden mapping of qtyping.BufferT data to buffer IDs to a qtyping.ModelT to make identifying duplicate buffers easier.
PiperOrigin-RevId: 895934926
1 parent 87947f1 commit 408262b

17 files changed

+207
-190
lines changed

ai_edge_quantizer/transformation_performer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,7 @@ def _apply_single_transformation(
247247
trans_info = self._transformation_registration[instruction.transformation](
248248
transformation_utils.TransformationInput(
249249
instruction.tensor_id,
250-
tflite_model.operatorCodes,
251-
tflite_model.buffers,
250+
tflite_model,
252251
tflite_model.subgraphs[transformation_inst.subgraph_id],
253252
producer,
254253
consumers,

ai_edge_quantizer/transformations/dequant_insert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def insert_dequant(
4141
"""
4242
dequant_op_code_idx = transformation_utils.add_op_code(
4343
qtyping.BuiltinOperator.DEQUANTIZE,
44-
transformation_input.op_codes,
44+
transformation_input.model.operatorCodes,
4545
)
4646
# create output tensor for the dequant op
4747
tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]

ai_edge_quantizer/transformations/dequant_insert_test.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ def test_dequant_insert_constant(self):
4747
# insert dequant on the constant before the add node
4848
dequant_insert.insert_dequant(
4949
transformation_utils.TransformationInput(
50-
7,
51-
model.operatorCodes,
52-
model.buffers,
53-
subgraph,
54-
-1,
55-
[4],
56-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
50+
tensor_id=7,
51+
model=model,
52+
subgraph=subgraph,
53+
producer=-1,
54+
consumers=[4],
55+
quant_params=qtyping.UniformQuantParams(
56+
8, None, np.array([1]), np.array([0])
57+
),
5758
)
5859
)
5960

@@ -87,13 +88,14 @@ def test_dequant_insert_activation(self):
8788
# insert dequant on the output of a conv node
8889
dequant_insert.insert_dequant(
8990
transformation_utils.TransformationInput(
90-
4,
91-
model.operatorCodes,
92-
model.buffers,
93-
subgraph,
94-
1,
95-
[3],
96-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
91+
tensor_id=4,
92+
model=model,
93+
subgraph=subgraph,
94+
producer=1,
95+
consumers=[3],
96+
quant_params=qtyping.UniformQuantParams(
97+
8, None, np.array([1]), np.array([0])
98+
),
9799
)
98100
)
99101

@@ -129,13 +131,14 @@ def test_dequant_insert_constant_multiple_consumers(self):
129131
# insert dequant on the input of a conv node
130132
post_trans_info = dequant_insert.insert_dequant(
131133
transformation_utils.TransformationInput(
132-
2,
133-
model.operatorCodes,
134-
model.buffers,
135-
subgraph,
136-
-1,
137-
[1, 2],
138-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
134+
tensor_id=2,
135+
model=model,
136+
subgraph=subgraph,
137+
producer=-1,
138+
consumers=[1, 2],
139+
quant_params=qtyping.UniformQuantParams(
140+
8, None, np.array([1]), np.array([0])
141+
),
139142
)
140143
)
141144
self.assertEqual(post_trans_info.op_id, 1)
@@ -173,13 +176,14 @@ def test_dequant_insert_activation_multiple_consumers(self):
173176
# insert dequant on the output of a conv node
174177
dequant_insert.insert_dequant(
175178
transformation_utils.TransformationInput(
176-
1,
177-
model.operatorCodes,
178-
model.buffers,
179-
subgraph,
180-
0,
181-
[1, 2],
182-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
179+
tensor_id=1,
180+
model=model,
181+
subgraph=subgraph,
182+
producer=0,
183+
consumers=[1, 2],
184+
quant_params=qtyping.UniformQuantParams(
185+
8, None, np.array([1]), np.array([0])
186+
),
183187
)
184188
)
185189

@@ -215,13 +219,14 @@ def test_dequant_insert_activation_multiple_consumers_select(self):
215219
# insert dequant on the output of a conv node
216220
dequant_insert.insert_dequant(
217221
transformation_utils.TransformationInput(
218-
1,
219-
model.operatorCodes,
220-
model.buffers,
221-
subgraph,
222-
0,
223-
[1],
224-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
222+
tensor_id=1,
223+
model=model,
224+
subgraph=subgraph,
225+
producer=0,
226+
consumers=[1],
227+
quant_params=qtyping.UniformQuantParams(
228+
8, None, np.array([1]), np.array([0])
229+
),
225230
)
226231
)
227232

@@ -257,13 +262,14 @@ def test_dequant_insert_on_graph_output(self):
257262
# insert dequant on the graph output
258263
dequant_insert.insert_dequant(
259264
transformation_utils.TransformationInput(
260-
8,
261-
model.operatorCodes,
262-
model.buffers,
263-
subgraph,
264-
4,
265-
[-1],
266-
qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
265+
tensor_id=8,
266+
model=model,
267+
subgraph=subgraph,
268+
producer=4,
269+
consumers=[-1],
270+
quant_params=qtyping.UniformQuantParams(
271+
8, None, np.array([1]), np.array([0])
272+
),
267273
)
268274
)
269275

ai_edge_quantizer/transformations/duplicate_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def duplicate_buffer(
2626
"""Duplicates the buffer of the tensor."""
2727
tensor_id = transformation_input.tensor_id
2828
tensor = transformation_input.subgraph.tensors[tensor_id]
29-
buffer_data = transformation_input.buffers[tensor.buffer].data
29+
buffer_data = transformation_input.model.buffers[tensor.buffer].data
3030
if buffer_data is None:
3131
tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
3232
raise ValueError(
@@ -36,7 +36,7 @@ def duplicate_buffer(
3636

3737
duplicated_buffer_id = transformation_utils.get_constant_buffer(
3838
data=buffer_data,
39-
buffers=transformation_input.buffers,
39+
model=transformation_input.model,
4040
force_duplicate_buffer=True,
4141
)
4242
tensor.buffer = duplicated_buffer_id

ai_edge_quantizer/transformations/duplicate_buffer_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
# ==============================================================================
1515

1616
import pathlib
17+
18+
from absl.testing import absltest
1719
import numpy as np
18-
import absl.testing.absltest as absltest
20+
1921
from ai_edge_quantizer import qtyping
2022
from ai_edge_quantizer.transformations import duplicate_buffer
2123
from ai_edge_quantizer.transformations import transformation_utils
@@ -40,9 +42,8 @@ def _get_transformation_input(
4042
) -> transformation_utils.TransformationInput:
4143
return transformation_utils.TransformationInput(
4244
tensor_id=tensor_idx,
43-
buffers=self.model.buffers,
45+
model=self.model,
4446
# Dummy params below.
45-
op_codes=self.model.operatorCodes,
4647
subgraph=self.model.subgraphs[subgraph_idx],
4748
producer=-1,
4849
consumers=[],

ai_edge_quantizer/transformations/duplicate_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def duplicate_tensor(
2828
subgraph = transformation_input.subgraph
2929
tensor = subgraph.tensors[tensor_id]
3030
tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
31-
buffer_data = transformation_input.buffers[tensor.buffer].data
31+
buffer_data = transformation_input.model.buffers[tensor.buffer].data
3232
if buffer_data is None:
3333
raise ValueError(
3434
'Duplicate Tensor transformation supports only constant tensors.'
@@ -40,7 +40,7 @@ def duplicate_tensor(
4040
tensor_type=tensor.type,
4141
tensor_shape=tensor.shape,
4242
subgraph=subgraph,
43-
buffers=transformation_input.buffers,
43+
model=transformation_input.model,
4444
force_duplicate_buffer=True,
4545
)
4646
# Update the tensor name to avoid name collision in case when tensor is

ai_edge_quantizer/transformations/duplicate_tensor_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
# ==============================================================================
1515

1616
import pathlib
17+
18+
from absl.testing import absltest
1719
import numpy as np
18-
import absl.testing.absltest as absltest
20+
1921
from ai_edge_quantizer import qtyping
2022
from ai_edge_quantizer.transformations import duplicate_tensor
2123
from ai_edge_quantizer.transformations import transformation_utils
@@ -43,10 +45,9 @@ def _get_transformation_input(
4345
) -> transformation_utils.TransformationInput:
4446
return transformation_utils.TransformationInput(
4547
tensor_id=tensor_idx,
46-
buffers=self.model.buffers,
48+
model=self.model,
4749
consumers=consumers,
4850
# Dummy params below.
49-
op_codes=self.model.operatorCodes,
5051
subgraph=self.model.subgraphs[subgraph_idx],
5152
producer=-1,
5253
quant_params=qtyping.UniformQuantParams(

ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def insert_decomposed_hadamard_rotation(
165165
np.array(prerotate_shape, dtype=np.int32),
166166
qtyping.TensorType.INT32,
167167
transformation_input.subgraph,
168-
transformation_input.buffers,
168+
transformation_input.model,
169169
)
170170
prerotate_reshape_output_tensor_id = (
171171
transformation_utils.add_new_activation_tensor(
@@ -178,7 +178,7 @@ def insert_decomposed_hadamard_rotation(
178178

179179
prerotate_reshape_op_code_idx = transformation_utils.add_op_code(
180180
qtyping.BuiltinOperator.RESHAPE,
181-
transformation_input.op_codes,
181+
transformation_input.model.operatorCodes,
182182
'RESHAPE',
183183
)
184184
prerorate_reshape_op = qtyping.OperatorT()
@@ -201,7 +201,7 @@ def insert_decomposed_hadamard_rotation(
201201
),
202202
tensor_type=qtyping.TensorType.INT4,
203203
subgraph=transformation_input.subgraph,
204-
buffers=transformation_input.buffers,
204+
model=transformation_input.model,
205205
tensor_shape=hadamard_matrix.shape,
206206
quantization=qtyping.QuantizationParametersT(
207207
scale=np.array([1.0 / np.sqrt(hadamard_size)], dtype=np.float32),
@@ -219,7 +219,7 @@ def insert_decomposed_hadamard_rotation(
219219

220220
fc_op_code_idx = transformation_utils.add_op_code(
221221
qtyping.BuiltinOperator.FULLY_CONNECTED,
222-
transformation_input.op_codes,
222+
transformation_input.model.operatorCodes,
223223
'FULLY_CONNECTED',
224224
)
225225
fc_op = qtyping.OperatorT()
@@ -234,7 +234,7 @@ def insert_decomposed_hadamard_rotation(
234234
# Insert x' = tfl.reshape(x', x.shape)
235235
post_reshape_op_code_idx = transformation_utils.add_op_code(
236236
qtyping.BuiltinOperator.RESHAPE,
237-
transformation_input.op_codes,
237+
transformation_input.model.operatorCodes,
238238
'RESHAPE',
239239
)
240240
post_reshape_op = qtyping.OperatorT()
@@ -244,7 +244,7 @@ def insert_decomposed_hadamard_rotation(
244244
np.array(tensor.shape, dtype=np.int32),
245245
qtyping.TensorType.INT32,
246246
transformation_input.subgraph,
247-
transformation_input.buffers,
247+
transformation_input.model,
248248
)
249249

250250
post_reshape_output_tensor_id = (

ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def test_raise_unsupported_qparams(self):
5656
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
5757
transformation_utils.TransformationInput(
5858
tensor_id=0,
59-
op_codes=self.model.operatorCodes,
60-
buffers=self.model.buffers,
59+
model=self.model,
6160
subgraph=self.model.subgraphs[0],
6261
producer=-1,
6362
consumers=[-1],
@@ -74,8 +73,7 @@ def test_raise_missing_hadamard_data(self):
7473
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
7574
transformation_utils.TransformationInput(
7675
tensor_id=0,
77-
op_codes=self.model.operatorCodes,
78-
buffers=self.model.buffers,
76+
model=self.model,
7977
subgraph=self.model.subgraphs[0],
8078
producer=-1,
8179
consumers=[-1],
@@ -96,8 +94,7 @@ def test_raise_non_float32_tensor(self):
9694
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
9795
transformation_utils.TransformationInput(
9896
tensor_id=0,
99-
op_codes=self.model.operatorCodes,
100-
buffers=self.model.buffers,
97+
model=self.model,
10198
subgraph=self.model.subgraphs[0],
10299
producer=-1,
103100
consumers=[-1],
@@ -111,8 +108,7 @@ def test_insert_decomposed_ops(self):
111108
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
112109
transformation_utils.TransformationInput(
113110
tensor_id=0,
114-
op_codes=self.model.operatorCodes,
115-
buffers=self.model.buffers,
111+
model=self.model,
116112
subgraph=self.model.subgraphs[0],
117113
producer=-1,
118114
consumers=[0], # Consumer is the FC op
@@ -201,8 +197,7 @@ def test_insert_decomposed_ops(self):
201197
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
202198
transformation_utils.TransformationInput(
203199
tensor_id=2, # Output of embedding_lookup
204-
op_codes=self.model.operatorCodes,
205-
buffers=self.model.buffers,
200+
model=self.model,
206201
subgraph=self.model.subgraphs[0],
207202
producer=0,
208203
consumers=[-1], # Output is a graph output

ai_edge_quantizer/transformations/insert_hadamard_rotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def insert_hadamard_rotation(
129129
# tensor as output.
130130
custom_op_code_idx = transformation_utils.add_op_code(
131131
qtyping.BuiltinOperator.CUSTOM,
132-
transformation_input.op_codes,
132+
transformation_input.model.operatorCodes,
133133
'aeq.hadamard_rotation',
134134
)
135135
custom_op = qtyping.OperatorT()

0 commit comments

Comments
 (0)