Skip to content

Commit 20ad73a

Browse files
rewu93copybara-github
authored andcommitted
Add memoryview support in get_constant_buffer.
PiperOrigin-RevId: 896578088
1 parent 38864d3 commit 20ad73a

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

ai_edge_quantizer/transformations/transformation_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""Utility functions for graph transformations."""
1717

18-
import copy
1918
import dataclasses
2019
from typing import Optional, Union
2120

@@ -85,7 +84,7 @@ def add_op_code(
8584

8685

8786
def get_constant_buffer(
88-
data: np.ndarray,
87+
data: np.ndarray | bytes | memoryview,
8988
buffers: list[qtyping.BufferT],
9089
force_duplicate_buffer: bool = False,
9190
) -> int:
@@ -107,12 +106,16 @@ def get_constant_buffer(
107106
# in the case where the data is passed from quantization_params.
108107
new_data = np.ravel(data.view(np.uint8))
109108
elif isinstance(data, bytes):
110-
# in the case where the data is coming from duplicating buffers, we need to
111-
# make a copy of the data to avoid having two buffers pointing to the same
112-
# data.
113-
new_data = copy.deepcopy(data)
109+
# Bytes are readonly, so we can just use them directly as the new data.
110+
new_data = data
111+
elif isinstance(data, memoryview):
112+
new_data = data.toreadonly()
114113
else:
115-
raise ValueError('data passed in must be either np.ndarray or bytes.')
114+
raise ValueError(
115+
'data passed in must be either np.ndarray, bytes, or memoryview.'
116+
' Got: %s'
117+
% type(data)
118+
)
116119
# TODO: b/417811116 - we should make this more efficient.
117120
if not force_duplicate_buffer:
118121
for index, buffer in enumerate(buffers):

0 commit comments

Comments
 (0)