diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py
index 7656ea3f363..0994156ae50 100644
--- a/exir/_serialize/_program.py
+++ b/exir/_serialize/_program.py
@@ -8,10 +8,11 @@
 import copy
 import json
+import math
 import re
 from dataclasses import dataclass
-from typing import ClassVar, List, Literal, Optional, Tuple
+from typing import ClassVar, Dict, List, Literal, Optional, Tuple
 
 from executorch.exir._serialize._cord import Cord
 from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -20,6 +21,10 @@
     _program_flatbuffer_to_json,
     _program_json_to_flatbuffer,
 )
+from executorch.exir._serialize._named_data_store import (
+    BufferEntry,
+    NamedDataStoreOutput,
+)
 
 from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
 
@@ -29,6 +34,7 @@
     Buffer,
     DataLocation,
     DataSegment,
+    NamedData,
     Program,
     SubsegmentOffsets,
 )
@@ -41,6 +47,24 @@
 _HEADER_BYTEORDER: Literal["little"] = "little"
 
 
+@dataclass
+class AlignedData:
+    """
+    Holds data that should be aligned, for serialization.
+
+    Attributes:
+        data: The data to serialize, as a cord.
+        alignment: The alignment required for the data.
+    """
+
+    data: Cord
+    alignment: int
+
+    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
+        self.data = data
+        self.alignment = alignment or 1
+
+
 def _program_to_json(program: Program) -> str:
     """Returns the JSON representation of the given Program."""
     return json.dumps(program, cls=_DataclassEncoder)
@@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
 
 def _extract_delegate_segments(
     program: Program,
-    segments: List[Cord],
+    segments: List[AlignedData],
 ) -> None:
     """Extracts the delegate segments inlined in the program into a list of buffers.
     The program is modified in-place to remove the delegate data.
@@ -253,7 +277,7 @@ def _extract_delegate_segments(
             segment_index = segment_index_map.get(inline.data)
             if segment_index is None:
                 segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segments.append(AlignedData(Cord(inline.data)))
                 segment_index_map[inline.data] = segment_index
             delegate.processed = BackendDelegateDataReference(
                 location=DataLocation.SEGMENT,
@@ -316,6 +340,44 @@ def _extract_constant_segment(
     return constant_segment_data, constant_segment_offsets
 
 
+def _extract_named_data(
+    program: Program,
+    segments: List[AlignedData],
+    buffers: List[BufferEntry],
+    name_to_buffer_idx: Dict[str, int],
+) -> None:
+    """Modifies the program in-place to add references to the named data
+    segments.
+
+    Args:
+        program: The program to extract segments from. Modified in-place.
+        segments: A list of buffers to append extracted segments to. Modified in-place.
+        buffers: A list of unique buffers and the information required to
+            serialize them. Not modified.
+        name_to_buffer_idx: A map from the name of a blob to the index in buffers.
+            Not modified.
+    """
+    if program.named_data is not None and len(program.named_data) > 0:
+        raise ValueError("Program already has named data.")
+
+    # Map from buffer_idx to segment_idx.
+    segment_index_map: Dict[int, int] = {}
+
+    named_data: List[NamedData] = []
+    for name, buffer_idx in name_to_buffer_idx.items():
+        segment_index = segment_index_map.get(buffer_idx, None)
+        if segment_index is None:
+            segment_index = len(segments)
+            segment_index_map[buffer_idx] = segment_index
+            segments.append(
+                AlignedData(
+                    Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
+                )
+            )
+        named_data.append(NamedData(key=name, segment_index=segment_index))
+    program.named_data = named_data
+
+
 def serialize_pte_binary(
     program: Program,
     *,
@@ -324,6 +386,7 @@ def serialize_pte_binary(
     segment_alignment: int = 128,
     constant_tensor_alignment: Optional[int] = None,
     delegate_alignment: Optional[int] = None,
+    named_data: Optional[NamedDataStoreOutput] = None,
 ) -> Cord:
     """Returns the runtime binary representation of the given Program.
 
@@ -343,6 +406,8 @@ def serialize_pte_binary(
         delegate_alignment: If provided, the minimum alignment of delegate data
             in the program. Must be a power of 2. If not provided, uses the value
             in the schema file.
+        named_data: If provided, named blobs to be stored in segments
+            after the PTE file.
 
     Returns:
         The serialized form of the Program, ready for execution by the runtime.
     """
@@ -355,8 +420,9 @@ def serialize_pte_binary(
     # copy, reusing the actual data blobs.
     program = copy.deepcopy(program)
 
-    # Store extracted segment data; this may be constant data or delegate data.
-    segments: List[Cord] = []
+    # Store extracted segment data, with any buffer-specific alignment.
+    # This may be constant data, delegate data or named data.
+    segments: List[AlignedData] = []
 
     constant_segment_data, constant_segment_offsets = _extract_constant_segment(
         program.constant_buffer, tensor_alignment=constant_tensor_alignment
@@ -374,7 +440,7 @@ def serialize_pte_binary(
         # Clear the constant buffer, as constant data will be stored in segments.
         program.constant_buffer = []
         # Add to the aggregate segments cord.
-        segments.append(constant_segment_data)
+        segments.append(AlignedData(constant_segment_data))
 
     if mutable_data is not None:
         mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
@@ -389,31 +455,34 @@ def serialize_pte_binary(
             ),
         ]
         # Add to the aggregate segments cord.
-        segments.append(mutable_segment_data)
+        segments.append(AlignedData(mutable_segment_data))
 
     if extract_delegate_segments:
         _extract_delegate_segments(program, segments)
+    if named_data is not None:
+        _extract_named_data(program, segments, named_data.buffers, named_data.pte_data)
 
     # Append all segments into a single Cord, adding any necessary padding to ensure that
     # each segment begins at the required alignment.
     # Update program.segments with the offsets to each segment.
     segments_data = Cord()
-    for data in segments:
+    for segment in segments:
         prev_end = (
             (program.segments[-1].offset + program.segments[-1].size)
             if program.segments
             else 0
         )
+        alignment = math.lcm(segment_alignment, segment.alignment)
         program.segments.append(
             DataSegment(
-                offset=aligned_size(prev_end, segment_alignment), size=len(data)
+                offset=aligned_size(prev_end, alignment), size=len(segment.data)
             )
         )
         # Add to aggregate segments cord with padding.
-        padding_length = padding_required(len(segments_data), segment_alignment)
+        padding_length = padding_required(len(segments_data), alignment)
         if padding_length > 0:
             segments_data.append(b"\x00" * padding_length)
-        segments_data.append(data)
+        segments_data.append(segment.data)
 
     # Convert to a standard flatbuffer binary.
     result: _FlatbufferResult = _program_json_to_flatbuffer(
diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py
index c311274922f..6351875e113 100644
--- a/exir/_serialize/_serialize.py
+++ b/exir/_serialize/_serialize.py
@@ -6,12 +6,12 @@
 
 # pyre-strict
 
-
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 from executorch.exir._serialize import _serialize_pte_binary
 
 from executorch.exir._serialize._cord import Cord
+from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
 from executorch.exir._serialize.data_serializer import (
     DataPayload,
     DataSerializer,
@@ -28,10 +28,24 @@ def serialize_for_executorch(
     emitter_output: EmitterOutput,
     config: ExecutorchBackendConfig,
     data_serializer: DataSerializer,
+    named_data: Optional[NamedDataStoreOutput] = None,
 ) -> Tuple[Cord, Dict[str, Cord]]:
     """Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
 
     # Serialize PTE file.
+    pte_named_data = None
+    if (
+        named_data is not None
+        and len(named_data.buffers) > 0
+        and len(named_data.pte_data) > 0
+    ):
+        # Create a separate NamedDataStoreOutput with only pte_data; exclude
+        # external_data, which shouldn't be serialized with the PTE file.
+        pte_named_data = NamedDataStoreOutput(
+            buffers=named_data.buffers,
+            pte_data=named_data.pte_data,
+            external_data={},
+        )
     pte: Cord = _serialize_pte_binary(
         program=emitter_output.program,
         mutable_data=emitter_output.mutable_data,
@@ -39,6 +53,7 @@ def serialize_for_executorch(
         segment_alignment=config.segment_alignment,
         constant_tensor_alignment=config.constant_tensor_alignment,
         delegate_alignment=config.delegate_alignment,
+        named_data=pte_named_data,
     )
 
     # Serialize PTD files.
@@ -88,4 +103,10 @@ def serialize_for_executorch(
             )
         )
 
+    if named_data is None or len(named_data.external_data) == 0:
+        return pte, ptd_files
+
+    if len(named_data.buffers) == 0:
+        raise RuntimeError("External data exists, but there are no buffers provided.")
+
     return pte, ptd_files
diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py
index f20c0b39798..c67849dd28d 100644
--- a/exir/_serialize/test/test_program.py
+++ b/exir/_serialize/test/test_program.py
@@ -10,11 +10,16 @@
 import copy
 import difflib
 import json
+import math
 import unittest
 from typing import List, Sequence
 
 from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json
+from executorch.exir._serialize._named_data_store import (
+    BufferEntry,
+    NamedDataStoreOutput,
+)
 from executorch.exir._serialize._program import (
     _ExtendedHeader,
     _get_extended_header,
@@ -23,6 +28,7 @@
     deserialize_pte_binary,
     serialize_pte_binary,
 )
+from executorch.exir._serialize.padding import aligned_size
 
 from executorch.exir.schema import (
     BackendDelegate,
@@ -552,11 +558,9 @@ def test_round_trip_with_segments(self) -> None:
         # Check the segment base offset boundary.
         segment_base_offset = eh.segment_base_offset
         self.assertEqual(
-            pte_data[segment_base_offset - 2 : segment_base_offset + 3],
-            # The padding before the first segment.
-            b"\x00\x00"
+            pte_data[segment_base_offset : segment_base_offset + 3],
             # The first few bytes of the first segment.
-            + b"\x10\x11\x11",
+            b"\x10\x11\x11",
         )
 
         # Now that we've shown that the base offset is correct, slice off the
@@ -671,7 +675,7 @@ def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None:
                 constant_tensor_alignment=constant_tensor_alignment,
             )
 
-    def test_constant_segment_and_delegate_segment(self) -> None:
+    def test_constant_delegate_and_named_data_segments(self) -> None:
         # Create a program with some constant tensor data and delegate data blobs.
         program = get_test_program()
         constant_blobs = (
@@ -682,10 +686,22 @@
             self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"),
             self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
         )
-
         add_constant_data(program, constant_blobs)
         add_delegate_data(program, program.execution_plan[0], delegate_blobs)
 
+        # Create named data segment.
+        named_data_buffers = [
+            BufferEntry(
+                buffer=self.gen_blob_data(8, b"\x50\x55\x05"), alignment=3
+            ),  # expect lcm(3, 128) = 384
+            BufferEntry(
+                buffer=self.gen_blob_data(16, b"\x60\x66\x06"), alignment=256
+            ),  # expect lcm(256, 128) = 256
+        ]
+        pte_named_data = {"key0": 0, "key1": 1}
+        named_data = NamedDataStoreOutput(
+            buffers=named_data_buffers, pte_data=pte_named_data, external_data={}
+        )
         # Extract the blobs into segments during serialization.
         pte_data = bytes(
             serialize_pte_binary(
@@ -693,6 +709,7 @@
                 extract_delegate_segments=True,
                 segment_alignment=SEGMENT_ALIGNMENT,
                 constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
+                named_data=named_data,
             )
         )
 
@@ -702,6 +719,7 @@
             program.execution_plan[0].delegates[0].processed.location,
             DataLocation.INLINE,
         )
+        self.assertEqual(program.named_data, [])
 
         # Extended header should be present in the serialized data.
         eh = self.get_and_validate_extended_header(pte_data)
@@ -715,9 +733,12 @@
 
         # Peek inside the actual flatbuffer data to see the segments.
         program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
-        # Segment table should contain a constant segment and the delegate blobs.
+        # Segment table should contain a constant segment, the delegate blobs
+        # and a named data segment.
         segment_table: List[DataSegment] = program_with_segments.segments
-        self.assertEqual(len(segment_table), len(delegate_blobs) + 1)
+        self.assertEqual(
+            len(segment_table), len(delegate_blobs) + len(pte_named_data) + 1
+        )
         self.assertEqual(segment_table[0].offset, 0)
         # segment_table[0] is the constant segment, which
         # contains a couple of tensors with sizes:
@@ -728,6 +749,30 @@
         self.assertEqual(segment_table[1].size, SEGMENT_ALIGNMENT // 2)
         self.assertEqual(segment_table[2].offset, SEGMENT_ALIGNMENT * 2)
         self.assertEqual(segment_table[2].size, SEGMENT_ALIGNMENT + 1)
+        # Named data segments.
+        expected_offset = aligned_size(
+            (segment_table[2].offset + segment_table[2].size),
+            math.lcm(named_data_buffers[0].alignment, SEGMENT_ALIGNMENT),
+        )
+        self.assertEqual(segment_table[3].offset, expected_offset)
+        self.assertEqual(segment_table[3].size, len(named_data_buffers[0].buffer))
+        expected_offset = aligned_size(
+            (segment_table[3].offset + segment_table[3].size),
+            math.lcm(named_data_buffers[1].alignment, SEGMENT_ALIGNMENT),
+        )
+        self.assertEqual(segment_table[4].offset, expected_offset)
+        self.assertEqual(segment_table[4].size, len(named_data_buffers[1].buffer))
+
+        # Named data.
+        self.assertTrue(program_with_segments.named_data is not None)
+        program_named_data = program_with_segments.named_data
+        self.assertEqual(len(program_named_data), len(pte_named_data))
+
+        # Check named data values.
+        self.assertEqual(program_named_data[0].key, "key0")
+        self.assertEqual(program_named_data[0].segment_index, 3)
+        self.assertEqual(program_named_data[1].key, "key1")
+        self.assertEqual(program_named_data[1].segment_index, 4)
 
         # Check constant_segment index and offsets.
         subsegment_offsets: SubsegmentOffsets = program_with_segments.constant_segment
@@ -811,6 +856,23 @@
             + b"\x40\x44\x44",
         )
 
+        # Check named data segments
+        self.assertEqual(
+            segment_data[
+                segment_table[3].offset : segment_table[3].offset
+                + segment_table[3].size
+            ],
+            named_data_buffers[0].buffer,
+        )
+
+        self.assertEqual(
+            segment_data[
+                segment_table[4].offset : segment_table[4].offset
+                + segment_table[4].size
+            ],
+            named_data_buffers[1].buffer,
+        )
+
         # Convert back.
         program2 = deserialize_pte_binary(pte_data)
         # Programs are the same besides constant_buffer, as deserialization
@@ -820,6 +882,104 @@
         # Number of constant tensors should be the same.
         self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
 
+    def test_named_data_segments(self) -> None:
+        # Set segment alignment to 12 to test the padding.
+        SEGMENT_ALIGNMENT: int = 12
+
+        # Create a program with some named data segments.
+        program = get_test_program()
+
+        # Create named data segments with different alignments.
+        buffers = [
+            BufferEntry(
+                buffer=self.gen_blob_data(8, b"\x10\x11\x01"), alignment=8
+            ),  # expect lcm(8, 12) = 24
+            BufferEntry(
+                buffer=self.gen_blob_data(16, b"\x20\x22\x02"), alignment=32
+            ),  # expect lcm(32, 12) = 96
+            BufferEntry(
+                buffer=self.gen_blob_data(24, b"\x30\x33\x03"), alignment=24
+            ),  # expect lcm(24, 12) = 24
+        ]
+        pte_named_data = {"key1": 0, "key2": 0, "key3": 1, "key4": 2}
+        named_data = NamedDataStoreOutput(
+            buffers=buffers, pte_data=pte_named_data, external_data={}
+        )
+        # Serialize the program with named data segments.
+        pte_data = bytes(
+            serialize_pte_binary(
+                program,
+                extract_delegate_segments=True,
+                segment_alignment=SEGMENT_ALIGNMENT,
+                constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
+                named_data=named_data,
+            )
+        )
+
+        # named_data is initially empty.
+        self.assertEqual(program.named_data, [])
+        # Extended header should be present in the serialized data.
+        eh = self.get_and_validate_extended_header(pte_data)
+        # Segment offset should be non-zero since there are segments. It
+        # should point past the end of the program data, but not beyond
+        # the end of the file.
+        self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
+        self.assertLess(eh.segment_base_offset, len(pte_data))
+
+        # Peek inside the actual flatbuffer data to see the named data segments.
+        program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
+        # pyre-ignore Incompatible parameter type [6]
+        self.assertEqual(len(program_with_segments.named_data), len(pte_named_data))
+
+        # Check Program.named_data values.
+        # pyre-ignore Undefined attribute [16]
+        self.assertEqual(program_with_segments.named_data[0].key, "key1")
+        self.assertEqual(program_with_segments.named_data[0].segment_index, 0)
+        self.assertEqual(program_with_segments.named_data[1].key, "key2")
+        self.assertEqual(program_with_segments.named_data[1].segment_index, 0)
+        self.assertEqual(program_with_segments.named_data[2].key, "key3")
+        self.assertEqual(program_with_segments.named_data[2].segment_index, 1)
+        self.assertEqual(program_with_segments.named_data[3].key, "key4")
+        self.assertEqual(program_with_segments.named_data[3].segment_index, 2)
+
+        # Check Program.segments values.
+        segment_table: List[DataSegment] = program_with_segments.segments
+        self.assertEqual(len(segment_table), 3)
+
+        for i in range(len(segment_table)):
+            segment_length = (
+                segment_table[i - 1].offset + segment_table[i - 1].size if i > 0 else 0
+            )
+            expected_offset = aligned_size(
+                segment_length, math.lcm(SEGMENT_ALIGNMENT, buffers[i].alignment)
+            )
+            self.assertEqual(segment_table[i].offset, expected_offset)
+            self.assertEqual(segment_table[i].size, len(buffers[i].buffer))
+
+        # Check the pte data for buffer values.
+        segment_data: bytes = pte_data[eh.segment_base_offset :]
+        self.assertEqual(
+            segment_data[
+                segment_table[0].offset : segment_table[0].offset
+                + segment_table[0].size
+            ],
+            buffers[0].buffer,
+        )
+        self.assertEqual(
+            segment_data[
+                segment_table[1].offset : segment_table[1].offset
+                + segment_table[1].size
+            ],
+            buffers[1].buffer,
+        )
+        self.assertEqual(
+            segment_data[
+                segment_table[2].offset : segment_table[2].offset
+                + segment_table[2].size
+            ],
+            buffers[2].buffer,
+        )
+
 
 # Common data for extended header tests. The two example values should produce
 # the example data.
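
Reviewer note (not part of the patch): every offset assertion in the two tests above follows the rule that serialize_pte_binary now implements — each segment starts at the previous segment's end, rounded up to lcm(segment_alignment, buffer.alignment). Below is a small standalone sketch of that arithmetic, using a local aligned_size helper that is assumed to mirror the round-up-to-multiple behavior of executorch.exir._serialize.padding.aligned_size, and the sizes and alignments from test_named_data_segments.

import math


def aligned_size(size: int, alignment: int) -> int:
    # Round `size` up to the next multiple of `alignment` (assumed semantics).
    return ((size + alignment - 1) // alignment) * alignment


SEGMENT_ALIGNMENT = 12            # segment_alignment used by test_named_data_segments
buffer_sizes = [8, 16, 24]        # len(buffers[i].buffer)
buffer_alignments = [8, 32, 24]   # buffers[i].alignment

prev_end = 0
for size, buf_alignment in zip(buffer_sizes, buffer_alignments):
    alignment = math.lcm(SEGMENT_ALIGNMENT, buf_alignment)  # 24, 96, 24
    offset = aligned_size(prev_end, alignment)
    print(offset, size)  # prints: 0 8, then 96 16, then 120 24
    prev_end = offset + size
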
diff --git a/exir/program/_program.py b/exir/program/_program.py
index fdf4b93e19c..7e0c3379f2a 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -16,6 +16,10 @@
 import torch
 import torch._export
 from executorch.exir._serialize._cord import Cord
+from executorch.exir._serialize._named_data_store import (
+    NamedDataStore,
+    NamedDataStoreOutput,
+)
 from executorch.exir._serialize._serialize import serialize_for_executorch
 from executorch.exir._serialize.data_serializer import DataSerializer
 from executorch.exir._warnings import experimental
@@ -1259,6 +1263,8 @@ def __init__(
         self._edge_programs: Dict[str, ExportedProgram] = edge_programs
         self._config_methods = constant_methods
 
+        self._named_data_store = NamedDataStore()
+
     @property
     def methods(self) -> Set[str]:
         """
@@ -1444,7 +1450,10 @@ def to_executorch(
             execution_programs[name] = program
 
         return ExecutorchProgramManager(
-            execution_programs, self._config_methods, config
+            execution_programs,
+            self._config_methods,
+            config,
+            self._named_data_store.get_named_data_store_output(),
         )
 
 
@@ -1465,6 +1474,7 @@ def __init__(
         execution_programs: Dict[str, ExportedProgram],
         config_methods: Optional[Dict[str, Any]] = None,
         backend_config: Optional[ExecutorchBackendConfig] = None,
+        named_data: Optional[NamedDataStoreOutput] = None,
     ):
         """
         End users should not call this constructor directly. Instead, they should use
@@ -1487,6 +1497,9 @@ def __init__(
         self._execution_programs: Dict[str, ExportedProgram] = execution_programs
         self._config_methods: Optional[Dict[str, Any]] = config_methods
 
+        # Named data from EdgeProgramManager
+        self._named_data: Optional[NamedDataStoreOutput] = named_data
+
         backend_config = backend_config or ExecutorchBackendConfig()
 
         # Emit methods
@@ -1499,7 +1512,10 @@ def __init__(
         # Serialize emitter output, ready to be written to a file.
         self._data_serializer = FlatTensorSerializer()
         self._pte_data, self._tensor_data = serialize_for_executorch(
-            self._emitter_output, backend_config, self._data_serializer
+            self._emitter_output,
+            backend_config,
+            self._data_serializer,
+            self._named_data,
         )
         self._buffer: Optional[bytes] = None
 
diff --git a/exir/tests/common.py b/exir/tests/common.py
index fdd7a3adca4..daeea109667 100644
--- a/exir/tests/common.py
+++ b/exir/tests/common.py
@@ -79,6 +79,7 @@ def get_test_program() -> Program:
         backend_delegate_data=[],
         segments=[],
         constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
+        named_data=[],
     )
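
Reviewer note (not part of the patch): a minimal end-to-end sketch of how the new named_data plumbing can be exercised directly against serialize_pte_binary, using only names that appear in this diff (BufferEntry, NamedDataStoreOutput, get_test_program). The blob contents, name, and alignment below are made up for illustration, and the import path for get_test_program assumes exir/tests/common.py is importable as executorch.exir.tests.common.

from executorch.exir._serialize._named_data_store import (
    BufferEntry,
    NamedDataStoreOutput,
)
from executorch.exir._serialize._program import serialize_pte_binary
from executorch.exir.tests.common import get_test_program

program = get_test_program()

# One 16-byte blob that must be 32-byte aligned, stored inside the PTE file.
named_data = NamedDataStoreOutput(
    buffers=[BufferEntry(buffer=b"\x01" * 16, alignment=32)],
    pte_data={"my_blob": 0},  # blob name -> index into buffers
    external_data={},  # external blobs would be routed to .ptd files instead
)

pte_file = bytes(
    serialize_pte_binary(
        program,
        extract_delegate_segments=False,
        segment_alignment=128,
        named_data=named_data,
    )
)
# The resulting PTE carries one extra segment holding the blob, placed at an
# offset aligned to lcm(128, 32) = 128, and Program.named_data maps "my_blob"
# to that segment's index.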