Skip to content

Commit eee3a00

Browse files
gonnet authored and copybara-github committed
Use mmap_utils for more efficient reading/writing of serialized model data.
PiperOrigin-RevId: 884497688
1 parent f56d705 commit eee3a00

File tree

3 files changed

+12
-57
lines changed

3 files changed

+12
-57
lines changed

ai_edge_quantizer/quantizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import os
2626
import io
27+
from ai_edge_litert.tools import mmap_utils
2728
from ai_edge_quantizer import algorithm_manager
2829
from ai_edge_quantizer import calibrator
2930
from ai_edge_quantizer import default_policy
@@ -87,8 +88,7 @@ def save(
8788
pathlib.Path(save_folder) / (model_name + '_recipe.json')
8889
)
8990
recipe = json.dumps(self.recipe)
90-
with open(recipe_save_path, 'w') as output_file_handle:
91-
output_file_handle.write(recipe)
91+
mmap_utils.set_file_contents(recipe_save_path, recipe.encode())
9292

9393
def export_model(self, filepath: Path, overwrite: bool = False) -> None:
9494
"""Exports the quantized model to a .tflite flatbuffer.
@@ -120,8 +120,9 @@ def export_model(self, filepath: Path, overwrite: bool = False) -> None:
120120
' consider change the model name or specify overwrite=True to'
121121
' overwrite the model if needed.'
122122
)
123-
with open(filepath, 'wb') as output_file_handle:
124-
output_file_handle.write(self.quantized_model)
123+
124+
# Try to write the file via an `mmap.mmap` to avoid any buffering.
125+
mmap_utils.set_file_contents(filepath, self.quantized_model)
125126

126127

127128
class Quantizer:
@@ -207,9 +208,8 @@ def load_config_policy(self, filename: Path) -> None:
207208
Args:
208209
filename: Config policy filename.
209210
"""
210-
with open(filename, 'r') as f:
211-
content = f.read()
212-
policy = default_policy.update_default_config_policy(content)
211+
content = bytearray(mmap_utils.get_file_contents(filename)).decode()
212+
policy = default_policy.update_default_config_policy(content)
213213

214214
# Register the policy for MIN_MAX_UNIFORM_QUANT algorithm.
215215
algorithm_manager.register_config_check_policy_func(

ai_edge_quantizer/utils/tfl_flatbuffer_utils.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
import collections
1919
import logging
20-
import mmap
21-
import os
2220
import pathlib
2321

2422
import immutabledict
2523
import numpy as np
2624

27-
import os
28-
import io
2925
from ai_edge_litert.tools import flatbuffer_utils
26+
from ai_edge_litert.tools import mmap_utils
3027
from ai_edge_quantizer import qtyping
3128

3229

@@ -151,26 +148,7 @@ def get_model_content(tflite_path: Path) -> memoryview:
151148
Returns:
152149
The model bytes.
153150
"""
154-
model_bytes = None
155-
156-
# Try to mmap the file first if it is local.
157-
try:
158-
if (fd := os.open(tflite_path, os.O_RDONLY)) >= 0:
159-
model_bytes = mmap.mmap(fd, 0, flags=mmap.MAP_SHARED, prot=mmap.PROT_READ)
160-
os.close(fd)
161-
except IOError as e:
162-
logging.info(
163-
'Mapping model file "%s" failed with exception: %s.',
164-
tflite_path,
165-
e,
166-
)
167-
168-
# If mapping failed, go at it conventionally.
169-
if model_bytes is None:
170-
with open(tflite_path, "rb") as tflite_file:
171-
model_bytes = tflite_file.read()
172-
173-
return memoryview(model_bytes)
151+
return mmap_utils.get_file_contents(tflite_path)
174152

175153

176154
def get_model_buffer(tflite_path: Path) -> bytearray:
@@ -182,28 +160,7 @@ def get_model_buffer(tflite_path: Path) -> bytearray:
182160
Returns:
183161
model_buffer: the model buffer.
184162
"""
185-
model_bytearray = None
186-
187-
# Try to mmap the file first if it is local.
188-
try:
189-
if (fd := os.open(tflite_path, os.O_RDONLY)) >= 0:
190-
try:
191-
model_mmap = mmap.mmap(
192-
fd, 0, flags=mmap.MAP_SHARED, prot=mmap.PROT_READ
193-
)
194-
model_bytearray = bytearray(model_mmap[:])
195-
except IOError as e:
196-
print(f"Mapping model file {tflite_path} failed with exception: {e}.")
197-
os.close(fd)
198-
except RuntimeError:
199-
pass
200-
201-
# If mapping failed, go at it conventionally.
202-
if model_bytearray is None:
203-
with open(tflite_path, "rb") as tflite_file:
204-
model_bytearray = bytearray(tflite_file.read())
205-
206-
return model_bytearray
163+
return bytearray(mmap_utils.get_file_contents(tflite_path))
207164

208165

209166
def parse_op_tensors(

ai_edge_quantizer/utils/tfl_interpreter_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
import ml_dtypes
2121
import numpy as np
2222

23+
from ai_edge_litert.tools import mmap_utils
2324
from ai_edge_quantizer import qtyping
2425
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
2526
from ai_edge_litert import interpreter as tfl # pylint: disable=g-direct-tensorflow-import
26-
import os
27-
import io
2827

2928
DEFAULT_SIGNATURE_KEY = "serving_default"
3029

@@ -52,8 +51,7 @@ def create_tfl_interpreter(
5251
A TFLite interpreter.
5352
"""
5453
if isinstance(tflite_model, str):
55-
with open(tflite_model, "rb") as f:
56-
tflite_model = f.read()
54+
tflite_model = mmap_utils.get_file_contents(tflite_model)
5755

5856
if use_xnnpack:
5957
op_resolver = tfl.OpResolverType.BUILTIN

0 commit comments

Comments (0)