@@ -140,6 +140,7 @@ def add_new_constant_tensor(
140140 buffers : list [schema_py_generated .BufferT ],
141141 tensor_shape : Optional [list [int ]] = None ,
142142 force_duplicate_buffer : bool = False ,
143+ quantization : schema_py_generated .QuantizationParametersT | None = None ,
143144) -> int :
144145 """Add a new constant tensor to the model.
145146
@@ -153,6 +154,8 @@ def add_new_constant_tensor(
153154 data will be used.
154155 force_duplicate_buffer: Whether to add a new buffer even if the same buffer
155156 already exists.
157+ quantization: Optional `QuantizationParametersT` describing the quantization
158+ of this tensor.
156159
157160 Returns:
158161 The index of the new tensor in the subgraph.
@@ -166,6 +169,7 @@ def add_new_constant_tensor(
166169 new_tensor .buffer = new_buffer_id
167170 new_tensor .type = tensor_type
168171 new_tensor .name = tensor_name
172+ new_tensor .quantization = quantization
169173 new_tensor_id = len (subgraph .tensors )
170174 subgraph .tensors .append (new_tensor )
171175 return new_tensor_id
@@ -176,6 +180,7 @@ def add_new_activation_tensor(
176180 shape : list [int ],
177181 tensor_type : schema_py_generated .TensorType ,
178182 subgraph : schema_py_generated .SubGraphT ,
183+ quantization : schema_py_generated .QuantizationParametersT | None = None ,
179184) -> int :
180185 """Add a new activation tensor to the model.
181186
@@ -184,6 +189,8 @@ def add_new_activation_tensor(
184189 shape: The shape of the new tensor.
185190 tensor_type: The type of the new tensor.
186191 subgraph: The subgraph where the new tensor is added.
192+ quantization: Optional `QuantizationParametersT` describing the quantization
193+ of this tensor.
187194
188195 Returns:
189196 The index of the new tensor in the subgraph.
@@ -199,6 +206,7 @@ def add_new_activation_tensor(
199206 new_tensor .shape = shape
200207 new_tensor .type = tensor_type
201208 new_tensor .name = tensor_name
209+ new_tensor .quantization = quantization
202210 new_tensor .buffer = 0
203211 new_tensor_id = len (subgraph .tensors )
204212 subgraph .tensors .append (new_tensor )
@@ -226,8 +234,9 @@ def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
226234 Packed data.
227235 """
228236 if bitwidth == 4 :
229- even_data = flattened_data [::2 ] & 0x0F
230- odd_data = np .left_shift (flattened_data [1 ::2 ], 4 ).astype (np .uint8 )
237+ flattened_data = np .bitwise_and (flattened_data .astype (np .uint8 ), 0x0F )
238+ even_data = flattened_data [::2 ]
239+ odd_data = np .left_shift (flattened_data [1 ::2 ], 4 )
231240 if odd_data .shape [0 ] == even_data .shape [0 ] - 1 :
232241 odd_data = np .pad (odd_data , (0 , 1 ), constant_values = 0 )
233242 return np .bitwise_or (even_data , odd_data )
0 commit comments