Skip to content

Serialize NamedData in PTE file #8847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 80 additions & 11 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import copy
import json
import math
import re

from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple
from typing import ClassVar, Dict, List, Literal, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
Expand All @@ -20,6 +21,10 @@
_program_flatbuffer_to_json,
_program_json_to_flatbuffer,
)
from executorch.exir._serialize._named_data_store import (
BufferEntry,
NamedDataStoreOutput,
)

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

Expand All @@ -29,6 +34,7 @@
Buffer,
DataLocation,
DataSegment,
NamedData,
Program,
SubsegmentOffsets,
)
Expand All @@ -41,6 +47,24 @@
_HEADER_BYTEORDER: Literal["little"] = "little"


@dataclass
class AlignedData:
    """
    Pairs a blob of segment data with the alignment it must be serialized at.

    Attributes:
        data: The data to serialize, as a cord.
        alignment: The alignment required for the data. A missing (None) or
            zero request is stored as 1.
    """

    data: Cord
    alignment: int

    # A hand-written __init__ is kept (instead of the dataclass-generated one)
    # so that `alignment` can be omitted or passed as None by callers.
    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        self.data = data
        # `or` maps both None and 0 to 1, so `alignment` is never falsy.
        self.alignment = alignment or 1


def _program_to_json(program: Program) -> str:
    """Returns the JSON representation of the given Program.

    Encoding is delegated to `json.dumps` with `_DataclassEncoder` as the
    encoder class.
    """
    return json.dumps(program, cls=_DataclassEncoder)
Expand Down Expand Up @@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:

def _extract_delegate_segments(
program: Program,
segments: List[Cord],
segments: List[AlignedData],
) -> None:
"""Extracts the delegate segments inlined in the program into a list of buffers.
The program is modified in-place to remove the delegate data.
Expand Down Expand Up @@ -253,7 +277,7 @@ def _extract_delegate_segments(
segment_index = segment_index_map.get(inline.data)
if segment_index is None:
segment_index = len(segments)
segments.append(Cord(inline.data))
segments.append(AlignedData(Cord(inline.data)))
segment_index_map[inline.data] = segment_index
delegate.processed = BackendDelegateDataReference(
location=DataLocation.SEGMENT,
Expand Down Expand Up @@ -316,6 +340,44 @@ def _extract_constant_segment(
return constant_segment_data, constant_segment_offsets


def _extract_named_data(
program: Program,
segments: List[AlignedData],
buffers: List[BufferEntry],
name_to_buffer_idx: Dict[str, int],
) -> None:
"""Modifies the program in-place to add references to the named data
segments.

Args:
program: The program to extract segments from. Modified in-place.
segments: A list of buffers to append extracted segments to. Modified in-place.
buffers: A list of unique buffers and the information required to
serialize them. Not modified.
name_to_buffer_idx: A map from the name of a blob to the index in buffers.
Not modified.
"""
if program.named_data is not None and len(program.named_data) > 0:
raise ValueError("Program already has named data.")

# Map from buffer_idx to segment_idx.
segment_index_map: Dict[int, int] = {}

named_data: List[NamedData] = []
for name, buffer_idx in name_to_buffer_idx.items():
segment_index = segment_index_map.get(buffer_idx, None)
if segment_index is None:
segment_index = len(segments)
segment_index_map[buffer_idx] = segment_index
segments.append(
AlignedData(
Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
)
)
named_data.append(NamedData(key=name, segment_index=segment_index))
program.named_data = named_data


def serialize_pte_binary(
program: Program,
*,
Expand All @@ -324,6 +386,7 @@ def serialize_pte_binary(
segment_alignment: int = 128,
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
named_data: Optional[NamedDataStoreOutput] = None,
) -> Cord:
"""Returns the runtime binary representation of the given Program.

Expand All @@ -343,6 +406,8 @@ def serialize_pte_binary(
delegate_alignment: If provided, the minimum alignment of delegate data
in the program. Must be a power of 2. If not provided, uses the
value in the schema file.
named_data: If provided, named blobs to be stored in segments
after the PTE file.
Returns:
The serialized form of the Program, ready for execution by the runtime.
"""
Expand All @@ -355,8 +420,9 @@ def serialize_pte_binary(
# copy, reusing the actual data blobs.
program = copy.deepcopy(program)

# Store extracted segment data; this may be constant data or delegate data.
segments: List[Cord] = []
# Store extracted segment data, with any buffer-specific alignment.
# This may be constant data, delegate data or named data.
segments: List[AlignedData] = []

constant_segment_data, constant_segment_offsets = _extract_constant_segment(
program.constant_buffer, tensor_alignment=constant_tensor_alignment
Expand All @@ -374,7 +440,7 @@ def serialize_pte_binary(
# Clear the constant buffer, as constant data will be stored in segments.
program.constant_buffer = []
# Add to the aggregate segments cord.
segments.append(constant_segment_data)
segments.append(AlignedData(constant_segment_data))

if mutable_data is not None:
mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
Expand All @@ -389,31 +455,34 @@ def serialize_pte_binary(
),
]
# Add to the aggregate segments cord.
segments.append(mutable_segment_data)
segments.append(AlignedData(mutable_segment_data))

if extract_delegate_segments:
_extract_delegate_segments(program, segments)
if named_data is not None:
_extract_named_data(program, segments, named_data.buffers, named_data.pte_data)

# Append all segments into a single Cord, adding any necessary padding to ensure that
# each segment begins at the required alignment.
# Update program.segments with the offsets to each segment.
segments_data = Cord()
for data in segments:
for segment in segments:
prev_end = (
(program.segments[-1].offset + program.segments[-1].size)
if program.segments
else 0
)
alignment = math.lcm(segment_alignment, segment.alignment)
program.segments.append(
DataSegment(
offset=aligned_size(prev_end, segment_alignment), size=len(data)
offset=aligned_size(prev_end, alignment), size=len(segment.data)
)
)
# Add to aggregate segments cord with padding.
padding_length = padding_required(len(segments_data), segment_alignment)
padding_length = padding_required(len(segments_data), alignment)
if padding_length > 0:
segments_data.append(b"\x00" * padding_length)
segments_data.append(data)
segments_data.append(segment.data)

# Convert to a standard flatbuffer binary.
result: _FlatbufferResult = _program_json_to_flatbuffer(
Expand Down
25 changes: 23 additions & 2 deletions exir/_serialize/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

# pyre-strict


from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

from executorch.exir._serialize import _serialize_pte_binary

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir._serialize.data_serializer import (
DataPayload,
DataSerializer,
Expand All @@ -28,17 +28,32 @@ def serialize_for_executorch(
emitter_output: EmitterOutput,
config: ExecutorchBackendConfig,
data_serializer: DataSerializer,
named_data: Optional[NamedDataStoreOutput] = None,
) -> Tuple[Cord, Dict[str, Cord]]:
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""

# Serialize PTE file.
pte_named_data = None
if (
named_data is not None
and len(named_data.buffers) > 0
and len(named_data.pte_data) > 0
):
# Create a separate NamedDataStoreOutput with only pte_data; exclude
# external_data, which shouldn't be serialized with the PTE file.
pte_named_data = NamedDataStoreOutput(
buffers=named_data.buffers,
pte_data=named_data.pte_data,
external_data={},
)
pte: Cord = _serialize_pte_binary(
program=emitter_output.program,
mutable_data=emitter_output.mutable_data,
extract_delegate_segments=config.extract_delegate_segments,
segment_alignment=config.segment_alignment,
constant_tensor_alignment=config.constant_tensor_alignment,
delegate_alignment=config.delegate_alignment,
named_data=pte_named_data,
)

# Serialize PTD files.
Expand Down Expand Up @@ -88,4 +103,10 @@ def serialize_for_executorch(
)
)

if named_data is None or len(named_data.external_data) == 0:
return pte, ptd_files

if len(named_data.buffers) == 0:
raise RuntimeError("External data exists, but there are no buffers provided.")

return pte, ptd_files
Loading
Loading