Skip to content

Commit f35422c

Browse files
authored
Serialize NamedData in PTE file (#8835)
1. Serialize NamedData in the PTE file. 2. Add NamedDataStore to EdgeProgramManager. --- Serializing NamedData is slightly different from serializing constant/delegate data, as each segment comes with its own alignment. **An example:** Given NamedData = {"key": data}, where data is 250 bytes: - BackendA requires data with alignment=3 - BackendB requires data with alignment=4. Then data should be serialized with an alignment of lcm(3, 4) = 12. At serialization, ExecuTorch has a 'segment_alignment' that defaults to 128, so data is now serialized with an alignment of lcm(12, 128) = 384. Inside the DataSegment, we want to store the original size of the data (250). The offset of the subsequent DataSegment would be 384 bytes after the start of this one. **Design** Introduce a new dataclass 'AlignedData' that stores the buffer and any alignment that's required. This is used when assembling Program.segments to ensure we get lcm(buffer_alignment, segment_alignment). Note: The default segment_alignment can be overridden inside 'ExecutorchBackendConfig'. Differential Revision: [D69764150](https://our.internmc.facebook.com/intern/diff/D69764150/) [ghstack-poisoned]
1 parent 4df0ade commit f35422c

File tree

5 files changed

+290
-23
lines changed

5 files changed

+290
-23
lines changed

exir/_serialize/_program.py

+80-11
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
import copy
1010
import json
11+
import math
1112
import re
1213

1314
from dataclasses import dataclass
14-
from typing import ClassVar, List, Literal, Optional, Tuple
15+
from typing import ClassVar, Dict, List, Literal, Optional, Tuple
1516

1617
from executorch.exir._serialize._cord import Cord
1718
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -20,6 +21,10 @@
2021
_program_flatbuffer_to_json,
2122
_program_json_to_flatbuffer,
2223
)
24+
from executorch.exir._serialize._named_data_store import (
25+
BufferEntry,
26+
NamedDataStoreOutput,
27+
)
2328

2429
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
2530

@@ -29,6 +34,7 @@
2934
Buffer,
3035
DataLocation,
3136
DataSegment,
37+
NamedData,
3238
Program,
3339
SubsegmentOffsets,
3440
)
@@ -41,6 +47,24 @@
4147
_HEADER_BYTEORDER: Literal["little"] = "little"
4248

4349

50+
@dataclass
class AlignedData:
    """
    A buffer to serialize, paired with the alignment it must be placed at.

    Attributes:
        data: The payload to serialize, held as a Cord.
        alignment: The alignment required for the payload. Falls back to 1
            when the caller supplies no (or a zero) alignment.
    """

    data: Cord
    alignment: int

    # The dataclass decorator keeps an explicitly defined __init__, so this
    # constructor is the one callers get.
    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # A missing (None) or zero alignment collapses to 1.
        self.alignment = alignment if alignment else 1
        self.data = data
66+
67+
4468
def _program_to_json(program: Program) -> str:
4569
"""Returns the JSON representation of the given Program."""
4670
return json.dumps(program, cls=_DataclassEncoder)
@@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
213237

214238
def _extract_delegate_segments(
215239
program: Program,
216-
segments: List[Cord],
240+
segments: List[AlignedData],
217241
) -> None:
218242
"""Extracts the delegate segments inlined in the program into a list of buffers.
219243
The program is modified in-place to remove the delegate data.
@@ -253,7 +277,7 @@ def _extract_delegate_segments(
253277
segment_index = segment_index_map.get(inline.data)
254278
if segment_index is None:
255279
segment_index = len(segments)
256-
segments.append(Cord(inline.data))
280+
segments.append(AlignedData(Cord(inline.data)))
257281
segment_index_map[inline.data] = segment_index
258282
delegate.processed = BackendDelegateDataReference(
259283
location=DataLocation.SEGMENT,
@@ -316,6 +340,44 @@ def _extract_constant_segment(
316340
return constant_segment_data, constant_segment_offsets
317341

318342

343+
def _extract_named_data(
    program: Program,
    segments: List[AlignedData],
    buffers: List[BufferEntry],
    name_to_buffer_idx: Dict[str, int],
) -> None:
    """Attaches named-data references to the program, in-place.

    Each distinct buffer index referenced by `name_to_buffer_idx` is appended
    to `segments` exactly once, together with that buffer's alignment. Every
    name then gets a NamedData entry pointing at the segment holding its
    buffer, so names that share a buffer share a segment.

    Args:
        program: The program to add named data references to. Its
            `named_data` field is overwritten. Modified in-place.
        segments: The list of segments that new named-data segments are
            appended to. Modified in-place.
        buffers: The unique buffers, each carrying the bytes and alignment
            needed to serialize it. Not modified.
        name_to_buffer_idx: A map from the name of a blob to its index in
            `buffers`. Not modified.

    Raises:
        ValueError: If the program already contains named data.
    """
    existing = program.named_data
    if not (existing is None or len(existing) == 0):
        raise ValueError("Program already has named data.")

    # Remembers which segment each buffer index was placed in, so a buffer
    # shared by several names is only emitted once.
    buffer_to_segment: Dict[int, int] = {}

    entries: List[NamedData] = []
    for key, buf_idx in name_to_buffer_idx.items():
        if buf_idx not in buffer_to_segment:
            # The new segment lands at the current end of the segment list.
            buffer_to_segment[buf_idx] = len(segments)
            entry = buffers[buf_idx]
            segments.append(AlignedData(Cord(entry.buffer), entry.alignment))
        entries.append(
            NamedData(key=key, segment_index=buffer_to_segment[buf_idx])
        )

    program.named_data = entries
379+
380+
319381
def serialize_pte_binary(
320382
program: Program,
321383
*,
@@ -324,6 +386,7 @@ def serialize_pte_binary(
324386
segment_alignment: int = 128,
325387
constant_tensor_alignment: Optional[int] = None,
326388
delegate_alignment: Optional[int] = None,
389+
named_data: Optional[NamedDataStoreOutput] = None,
327390
) -> Cord:
328391
"""Returns the runtime binary representation of the given Program.
329392
@@ -343,6 +406,8 @@ def serialize_pte_binary(
343406
delegate_alignment: If provided, the minimum alignment of delegate data
344407
in the program. Must be a power of 2. If not provided, uses the
345408
value in the schema file.
409+
named_data: If provided, named blobs to be stored in segments
410+
after the PTE file.
346411
Returns:
347412
The serialized form of the Program, ready for execution by the runtime.
348413
"""
@@ -355,8 +420,9 @@ def serialize_pte_binary(
355420
# copy, reusing the actual data blobs.
356421
program = copy.deepcopy(program)
357422

358-
# Store extracted segment data; this may be constant data or delegate data.
359-
segments: List[Cord] = []
423+
# Store extracted segment data, with any buffer-specific alignment.
424+
# This may be constant data, delegate data or named data.
425+
segments: List[AlignedData] = []
360426

361427
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
362428
program.constant_buffer, tensor_alignment=constant_tensor_alignment
@@ -374,7 +440,7 @@ def serialize_pte_binary(
374440
# Clear the constant buffer, as constant data will be stored in segments.
375441
program.constant_buffer = []
376442
# Add to the aggregate segments cord.
377-
segments.append(constant_segment_data)
443+
segments.append(AlignedData(constant_segment_data))
378444

379445
if mutable_data is not None:
380446
mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
@@ -389,31 +455,34 @@ def serialize_pte_binary(
389455
),
390456
]
391457
# Add to the aggregate segments cord.
392-
segments.append(mutable_segment_data)
458+
segments.append(AlignedData(mutable_segment_data))
393459

394460
if extract_delegate_segments:
395461
_extract_delegate_segments(program, segments)
462+
if named_data is not None:
463+
_extract_named_data(program, segments, named_data.buffers, named_data.pte_data)
396464

397465
# Append all segments into a single Cord, adding any necessary padding to ensure that
398466
# each segment begins at the required alignment.
399467
# Update program.segments with the offsets to each segment.
400468
segments_data = Cord()
401-
for data in segments:
469+
for segment in segments:
402470
prev_end = (
403471
(program.segments[-1].offset + program.segments[-1].size)
404472
if program.segments
405473
else 0
406474
)
475+
alignment = math.lcm(segment_alignment, segment.alignment)
407476
program.segments.append(
408477
DataSegment(
409-
offset=aligned_size(prev_end, segment_alignment), size=len(data)
478+
offset=aligned_size(prev_end, alignment), size=len(segment.data)
410479
)
411480
)
412481
# Add to aggregate segments cord with padding.
413-
padding_length = padding_required(len(segments_data), segment_alignment)
482+
padding_length = padding_required(len(segments_data), alignment)
414483
if padding_length > 0:
415484
segments_data.append(b"\x00" * padding_length)
416-
segments_data.append(data)
485+
segments_data.append(segment.data)
417486

418487
# Convert to a standard flatbuffer binary.
419488
result: _FlatbufferResult = _program_json_to_flatbuffer(

exir/_serialize/_serialize.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
# pyre-strict
88

9-
10-
from typing import Dict, Tuple
9+
from typing import Dict, Optional, Tuple
1110

1211
from executorch.exir._serialize import _serialize_pte_binary
1312

1413
from executorch.exir._serialize._cord import Cord
14+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1515
from executorch.exir._serialize.data_serializer import (
1616
DataPayload,
1717
DataSerializer,
@@ -28,17 +28,32 @@ def serialize_for_executorch(
2828
emitter_output: EmitterOutput,
2929
config: ExecutorchBackendConfig,
3030
data_serializer: DataSerializer,
31+
named_data: Optional[NamedDataStoreOutput] = None,
3132
) -> Tuple[Cord, Dict[str, Cord]]:
3233
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
3334

3435
# Serialize PTE file.
36+
pte_named_data = None
37+
if (
38+
named_data is not None
39+
and len(named_data.buffers) > 0
40+
and len(named_data.pte_data) > 0
41+
):
42+
# Create a separate NamedDataStoreOutput with only pte_data; exclude
43+
# external_data, which shouldn't be serialized with the PTE file.
44+
pte_named_data = NamedDataStoreOutput(
45+
buffers=named_data.buffers,
46+
pte_data=named_data.pte_data,
47+
external_data={},
48+
)
3549
pte: Cord = _serialize_pte_binary(
3650
program=emitter_output.program,
3751
mutable_data=emitter_output.mutable_data,
3852
extract_delegate_segments=config.extract_delegate_segments,
3953
segment_alignment=config.segment_alignment,
4054
constant_tensor_alignment=config.constant_tensor_alignment,
4155
delegate_alignment=config.delegate_alignment,
56+
named_data=pte_named_data,
4257
)
4358

4459
# Serialize PTD files.
@@ -88,4 +103,10 @@ def serialize_for_executorch(
88103
)
89104
)
90105

106+
if named_data is None or len(named_data.external_data) == 0:
107+
return pte, ptd_files
108+
109+
if len(named_data.buffers) == 0:
110+
raise RuntimeError("External data exists, but there are no buffers provided.")
111+
91112
return pte, ptd_files

0 commit comments

Comments
 (0)