1515
1616"""Utility functions for graph transformations."""
1717
18- import copy
1918import dataclasses
2019from typing import Optional , Union
2120
@@ -85,7 +84,7 @@ def add_op_code(
8584
8685
8786def get_constant_buffer (
88- data : np .ndarray ,
87+ data : np .ndarray | bytes | memoryview ,
8988 buffers : list [qtyping .BufferT ],
9089 force_duplicate_buffer : bool = False ,
9190) -> int :
@@ -107,12 +106,16 @@ def get_constant_buffer(
107106 # in the case where the data is passed from quantization_params.
108107 new_data = np .ravel (data .view (np .uint8 ))
109108 elif isinstance (data , bytes ):
110- # in the case where the data is coming from duplicating buffers, we need to
111- # make a copy of the data to avoid having two buffers pointing to the same
112- # data.
113- new_data = copy . deepcopy ( data )
109+ # Bytes are readonly, so we can just use them directly as the new data.
110+ new_data = data
111+ elif isinstance ( data , memoryview ):
112+ new_data = data . toreadonly ( )
114113 else :
115- raise ValueError ('data passed in must be either np.ndarray or bytes.' )
114+ raise ValueError (
115+ 'data passed in must be either np.ndarray, bytes, or memoryview.'
116+ ' Got: %s'
117+ % type (data )
118+ )
116119 # TODO: b/417811116 - we should make this more efficient.
117120 if not force_duplicate_buffer :
118121 for index , buffer in enumerate (buffers ):
0 commit comments