Skip to content

Commit c0fd26d

Browse files
committed
[executorch][flat_tensor] Serialize flat tensor
Pull Request resolved: #7268 Serialize a flat tensor file. The resulting file looks like: Header containing: - flatbuffer offset and size - segment data offset and size Flatbuffer containing: - Items described in [flat_tensor.fbs](https://www.internalfb.com/code/fbsource/[079ba95593be856a16783bd3f3b3579580595fbb]/fbcode/executorch/extension/flat_tensor/flat_tensor.fbs) Tensor data (in segment) - Raw tensor data ghstack-source-id: 260056750 @exported-using-ghexport Differential Revision: [D66374253](https://our.internmc.facebook.com/intern/diff/D66374253/)
1 parent 976e008 commit c0fd26d

File tree

7 files changed

+398
-0
lines changed

7 files changed

+398
-0
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.python_library(
3333
"_dataclass.py",
3434
"_flatbuffer.py",
3535
"_program.py",
36+
"data_serializer.py",
3637
"padding.py",
3738
],
3839
resources = {

extension/flat_tensor/__init__.py

Whitespace-only changes.

extension/flat_tensor/serialize/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,21 @@ runtime.python_library(
1414
"//executorch/...",
1515
],
1616
)
17+
18+
# Serialization entry point for the FlatTensor (.ptd) format. The two .fbs
# flatbuffer schemas are bundled as resources because serialize.py loads
# them at runtime via pkg_resources to drive flatc.
runtime.python_library(
    name = "serialize",
    srcs = [
        "serialize.py",
    ],
    resources = [
        "flat_tensor.fbs",
        "scalar_type.fbs",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        ":schema",
        "//executorch/exir/_serialize:lib",
    ],
)

extension/flat_tensor/serialize/__init__.py

Whitespace-only changes.
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
import json
2+
import os
3+
import tempfile
4+
from dataclasses import dataclass
5+
from typing import ClassVar, Dict, List, Literal, Optional
6+
7+
import pkg_resources
8+
from executorch.exir._serialize._cord import Cord
9+
from executorch.exir._serialize._dataclass import _DataclassEncoder
10+
11+
from executorch.exir._serialize._flatbuffer import _flatc_compile
12+
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
13+
14+
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
15+
16+
# Byte order of numbers written to flat tensor headers. Always little-endian
17+
# regardless of the host system, since all commonly-used modern CPUs are little
18+
# endian.
19+
_HEADER_BYTEORDER: Literal["little"] = "little"
20+
21+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
22+
DataSegment,
23+
FlatTensor,
24+
TensorMetadata,
25+
)
26+
27+
28+
def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
    """Serializes a FlatTensor dataclass to flatbuffer binary data.

    The dataclass is dumped to JSON, the bundled .fbs schemas are
    materialized into a scratch directory, and flatc compiles the JSON
    against them. Returns the resulting binary as a Cord.
    """
    payload_json = json.dumps(flat_tensor, cls=_DataclassEncoder)

    with tempfile.TemporaryDirectory() as work_dir:
        # flatc needs the schema — and the scalar_type schema it includes —
        # on disk, next to the JSON input.
        for resource_name in ("flat_tensor.fbs", "scalar_type.fbs"):
            with open(os.path.join(work_dir, resource_name), "wb") as out:
                out.write(pkg_resources.resource_string(__name__, resource_name))

        json_path = os.path.join(work_dir, "flat_tensor.json")
        with open(json_path, "wb") as out:
            out.write(payload_json.encode("ascii"))

        _flatc_compile(
            work_dir, os.path.join(work_dir, "flat_tensor.fbs"), json_path
        )

        # Read back the binary that flatc emitted for flat_tensor.json.
        with open(os.path.join(work_dir, "flat_tensor.ptd"), "rb") as compiled:
            return Cord(compiled.read())
50+
51+
52+
@dataclass
class FlatTensorConfig:
    """Alignment settings used when laying out a flat_tensor file."""

    # Alignment, in bytes, that each tensor's start offset within the
    # segment data is padded to.
    tensor_alignment: int = 16
    # Alignment, in bytes, applied to the segment base offset and to the
    # end of the segment data.
    segment_alignment: int = 16
56+
57+
58+
@dataclass
class FlatTensorHeader:
    """Fixed-size binary header placed at the start of a serialized
    flat_tensor file. Records where the flatbuffer and the segment data
    live so consumers can locate both without parsing the flatbuffer.
    """

    # Class constants.
    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"FH01"
    # Serialized size: magic (4) + length (4) + four uint64 fields (8 each).
    EXPECTED_LENGTH: ClassVar[int] = 4 + 4 + 8 + 8 + 8 + 8

    # Instance attributes. @dataclass will turn these into ctor args.

    # Offset to the start of the flatbuffer data, in bytes.
    flatbuffer_offset: int
    # The size of the serialized flatbuffer data in bytes.
    flatbuffer_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int
    # Size of all the segment data, in bytes.
    segment_data_size: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "FlatTensorHeader":
        """Tries to read a flat_tensor header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from. Bytes beyond EXPECTED_LENGTH are
                ignored.
        Returns:
            The contents of the flat_tensor header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < FlatTensorHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for flat_tensor header: {len(data)} "
                + f"< {FlatTensorHeader.EXPECTED_LENGTH}"
            )

        def read_int(lo: int, hi: int) -> int:
            # Every integer field is serialized little-endian.
            return int.from_bytes(data[lo:hi], byteorder=_HEADER_BYTEORDER)

        return FlatTensorHeader(
            magic=data[0:4],
            length=read_int(4, 8),
            flatbuffer_offset=read_int(8, 16),
            flatbuffer_size=read_int(16, 24),
            segment_base_offset=read_int(24, 32),
            segment_data_size=read_int(32, 40),
        )

    def is_valid(self) -> bool:
        """Returns true if the flat_tensor header appears to be well-formed."""
        if self.magic != FlatTensorHeader.EXPECTED_MAGIC:
            return False
        # A longer-than-expected header is still valid: a future version may
        # append fields, and `length` tells readers how much to skip.
        return self.length >= FlatTensorHeader.EXPECTED_LENGTH

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the flat_tensor header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length.
        """
        fields = (
            # Header magic, so consumers can detect whether the header was
            # inserted. Always the canonical value — there's no reason to
            # produce an invalid header.
            self.EXPECTED_MAGIC,
            # uint32_t: size of this header, enabling future field additions.
            # Always the canonical value, ignoring self.length.
            self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER),
            # uint64_t: offset to the start of the flatbuffer data, in bytes.
            self.flatbuffer_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: size of the serialized flatbuffer data, in bytes.
            self.flatbuffer_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: offset to the first segment, or zero if none.
            self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: size of all the segment data, in bytes.
            self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
        )
        return b"".join(fields)
162+
163+
164+
class FlatTensorSerializer(DataSerializer):
    """A concrete implementation of the DataSerializer interface that
    serializes and deserializes data to/from the FlatTensor format.

    The serialized output is laid out as:
      1. FlatTensorHeader (padded to tensor_alignment),
      2. the FlatTensor flatbuffer (padded to tensor_alignment),
      3. the raw tensor data segment (padded to segment_alignment).
    """

    def __init__(self, config: Optional[FlatTensorConfig] = None) -> None:
        """FlatTensorConfig holds information required for serialization,
        eg. alignment.

        Args:
            config: Alignment configuration; defaults to FlatTensorConfig()
                when None.
        """
        if config is None:
            self.config = FlatTensorConfig()
        else:
            self.config = config

    def serialize(
        self,
        data: DataPayload,
    ) -> Cord:
        """Serializes a list of tensor metadata and tensors into a blob.

        Args:
            data: The tensor buffers and the fqn -> tensor-entry mapping
                to serialize.
        Returns:
            A Cord holding header + flatbuffer + segment data, in that order.
        """

        flat_tensor_metadata: List[TensorMetadata] = []
        flat_tensor_data: Cord = Cord()

        # Maps buffer index -> offset within flat_tensor_data, so that a
        # buffer shared by several tensors is stored only once.
        # {idx, offset}
        saved_offsets: Dict[int, int] = {}

        for fqn, tensor_entry in data.fqn_to_tensor.items():
            assert tensor_entry.layout is not None
            # Check index into the tensor buffers is valid.
            assert tensor_entry.buffer_index < len(
                data.buffers
            ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."

            # Check if the tensor has already been appended to the flat_tensor_data.
            offset = saved_offsets.get(tensor_entry.buffer_index, -1)
            if offset == -1:
                if len(flat_tensor_data) > 0:
                    # Add padding to round off the previous tensor offset,
                    # so this buffer starts at a tensor_alignment boundary.
                    pad_length = padding_required(
                        len(flat_tensor_data), self.config.tensor_alignment
                    )
                    flat_tensor_data.append(b"\x00" * pad_length)
                # Add to saved offsets.
                offset = len(flat_tensor_data)
                saved_offsets[tensor_entry.buffer_index] = offset
                # Append to flat_tensor_data at the offset.
                flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])

            # All tensors live in a single segment, hence segment_index=0.
            flat_tensor_metadata.append(
                TensorMetadata(
                    fully_qualified_name=fqn,
                    scalar_type=tensor_entry.layout.scalar_type,
                    sizes=tensor_entry.layout.sizes,
                    dim_order=tensor_entry.layout.dim_order,
                    segment_index=0,
                    offset=offset,
                )
            )

        # Pad flat_tensor_data to segment alignment.
        segment_pad_length = padding_required(
            len(flat_tensor_data), self.config.segment_alignment
        )
        if segment_pad_length > 0:
            flat_tensor_data.append(b"\x00" * segment_pad_length)

        # Create FlatTensor, which describes the contents of the file and
        # points to all the data segments. It will be serialized to flatbuffer.
        flat_tensor = FlatTensor(
            version=0,
            tensor_alignment=self.config.tensor_alignment,
            tensors=flat_tensor_metadata,
            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
        )

        flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
        padded_flatbuffer_length: int = aligned_size(
            input_size=len(flatbuffer_payload),
            alignment=self.config.tensor_alignment,
        )

        padded_header_length: int = aligned_size(
            input_size=FlatTensorHeader.EXPECTED_LENGTH,
            alignment=self.config.tensor_alignment,
        )

        # The segment starts after the (padded) header and flatbuffer,
        # rounded up to the segment alignment.
        segment_base_offset = aligned_size(
            padded_flatbuffer_length + padded_header_length,
            self.config.segment_alignment,
        )

        # Create FlatTensorHeader, which stores the offsets and sizes of the
        # FlatTensor flatbuffer and the segment data.
        header_data: bytes = FlatTensorHeader(
            flatbuffer_offset=padded_header_length,
            flatbuffer_size=len(flatbuffer_payload),
            segment_base_offset=segment_base_offset,
            segment_data_size=len(flat_tensor_data),
        ).to_bytes()

        # Pad header and payload to segment alignment.
        header_data = pad_to(header_data, padded_header_length)
        flatbuffer_payload.append(
            b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
        )

        # Place everything into one segment.
        payload = Cord()
        payload.append(header_data)
        payload.append(flatbuffer_payload)
        payload.append(flat_tensor_data)

        return payload

    def deserialize(self, blob: Cord) -> DataPayload:
        """
        Deserializes a flat_tensor blob into a list of tensor metadata and tensors.

        Raises:
            NotImplementedError: Always; deserialization is not yet implemented.
        """
        raise NotImplementedError("deserialize_data")

extension/flat_tensor/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Test targets for FlatTensor serialization.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_unittest(
    name = "serialize",
    srcs = [
        "test_serialize.py",
    ],
    deps = [
        "//executorch/extension/flat_tensor/serialize:serialize",
        "//executorch/extension/flat_tensor/serialize:schema",
    ],
)

0 commit comments

Comments
 (0)