8
8
9
9
import copy
10
10
import json
11
+ import math
11
12
import re
12
13
13
14
from dataclasses import dataclass
14
- from typing import ClassVar , List , Literal , Optional , Tuple
15
+ from typing import ClassVar , Dict , List , Literal , Optional , Tuple
15
16
16
17
from executorch .exir ._serialize ._cord import Cord
17
18
from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
20
21
_program_flatbuffer_to_json ,
21
22
_program_json_to_flatbuffer ,
22
23
)
24
+ from executorch .exir ._serialize ._named_data_store import (
25
+ BufferEntry ,
26
+ NamedDataStoreOutput ,
27
+ )
23
28
24
29
from executorch .exir ._serialize .padding import aligned_size , pad_to , padding_required
25
30
29
34
Buffer ,
30
35
DataLocation ,
31
36
DataSegment ,
37
+ NamedData ,
32
38
Program ,
33
39
SubsegmentOffsets ,
34
40
)
41
47
_HEADER_BYTEORDER : Literal ["little" ] = "little"
42
48
43
49
50
@dataclass
class AlignedData:
    """A blob of segment data paired with the alignment it must be written at.

    Attributes:
        data: The bytes to serialize, held as a Cord.
        alignment: The alignment required for `data`. Falls back to 1 when the
            caller omits it or passes a falsy value (None or 0).
    """

    data: Cord
    alignment: int

    # A hand-written __init__ is used so that `alignment` can be omitted or
    # None; @dataclass leaves an __init__ already defined on the class alone.
    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        self.data = data
        self.alignment = alignment if alignment else 1
66
+
67
+
44
68
def _program_to_json(program: Program) -> str:
    """Serializes `program` to a JSON string using `_DataclassEncoder`."""
    encoded: str = json.dumps(program, cls=_DataclassEncoder)
    return encoded
@@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
213
237
214
238
def _extract_delegate_segments (
215
239
program : Program ,
216
- segments : List [Cord ],
240
+ segments : List [AlignedData ],
217
241
) -> None :
218
242
"""Extracts the delegate segments inlined in the program into a list of buffers.
219
243
The program is modified in-place to remove the delegate data.
@@ -253,7 +277,7 @@ def _extract_delegate_segments(
253
277
segment_index = segment_index_map .get (inline .data )
254
278
if segment_index is None :
255
279
segment_index = len (segments )
256
- segments .append (Cord (inline .data ))
280
+ segments .append (AlignedData ( Cord (inline .data ) ))
257
281
segment_index_map [inline .data ] = segment_index
258
282
delegate .processed = BackendDelegateDataReference (
259
283
location = DataLocation .SEGMENT ,
@@ -316,6 +340,44 @@ def _extract_constant_segment(
316
340
return constant_segment_data , constant_segment_offsets
317
341
318
342
343
def _extract_named_data(
    program: Program,
    segments: List[AlignedData],
    buffers: List[BufferEntry],
    name_to_buffer_idx: Dict[str, int],
) -> None:
    """Appends named-data buffers to `segments` and records them on the program.

    Each distinct buffer index referenced by `name_to_buffer_idx` is appended
    to `segments` once, wrapped in an `AlignedData` carrying that buffer's
    alignment. `program.named_data` is then replaced with one `NamedData` entry
    per name, pointing at the segment that holds its buffer.

    Args:
        program: The program to update. Modified in-place.
        segments: The list that extracted segments are appended to. Modified
            in-place.
        buffers: The unique buffers and the information required to serialize
            them. Only read here.
        name_to_buffer_idx: Maps each blob name to its index in `buffers`.
            Only read here.

    Raises:
        ValueError: If `program.named_data` is already non-empty.
    """
    existing = program.named_data
    if not (existing is None or len(existing) == 0):
        raise ValueError("Program already has named data.")

    # buffer index -> index in `segments`; lets several names that share a
    # buffer point at the same segment instead of duplicating the data.
    buffer_to_segment: Dict[int, int] = {}
    entries: List[NamedData] = []
    for key, buf_idx in name_to_buffer_idx.items():
        if buf_idx not in buffer_to_segment:
            entry = buffers[buf_idx]
            buffer_to_segment[buf_idx] = len(segments)
            segments.append(AlignedData(Cord(entry.buffer), entry.alignment))
        entries.append(
            NamedData(key=key, segment_index=buffer_to_segment[buf_idx])
        )
    program.named_data = entries
379
+
380
+
319
381
def serialize_pte_binary (
320
382
program : Program ,
321
383
* ,
@@ -324,6 +386,7 @@ def serialize_pte_binary(
324
386
segment_alignment : int = 128 ,
325
387
constant_tensor_alignment : Optional [int ] = None ,
326
388
delegate_alignment : Optional [int ] = None ,
389
+ named_data : Optional [NamedDataStoreOutput ] = None ,
327
390
) -> Cord :
328
391
"""Returns the runtime binary representation of the given Program.
329
392
@@ -343,6 +406,8 @@ def serialize_pte_binary(
343
406
delegate_alignment: If provided, the minimum alignment of delegate data
344
407
in the program. Must be a power of 2. If not provided, uses the
345
408
value in the schema file.
409
+ named_data: If provided, named blobs to be stored in segments
410
+ after the PTE file.
346
411
Returns:
347
412
The serialized form of the Program, ready for execution by the runtime.
348
413
"""
@@ -355,8 +420,9 @@ def serialize_pte_binary(
355
420
# copy, reusing the actual data blobs.
356
421
program = copy .deepcopy (program )
357
422
358
- # Store extracted segment data; this may be constant data or delegate data.
359
- segments : List [Cord ] = []
423
+ # Store extracted segment data, with any buffer-specific alignment.
424
+ # This may be constant data, delegate data or named data.
425
+ segments : List [AlignedData ] = []
360
426
361
427
constant_segment_data , constant_segment_offsets = _extract_constant_segment (
362
428
program .constant_buffer , tensor_alignment = constant_tensor_alignment
@@ -374,7 +440,7 @@ def serialize_pte_binary(
374
440
# Clear the constant buffer, as constant data will be stored in segments.
375
441
program .constant_buffer = []
376
442
# Add to the aggregate segments cord.
377
- segments .append (constant_segment_data )
443
+ segments .append (AlignedData ( constant_segment_data ) )
378
444
379
445
if mutable_data is not None :
380
446
mutable_segment_data , mutable_segment_offsets = _extract_constant_segment (
@@ -389,31 +455,34 @@ def serialize_pte_binary(
389
455
),
390
456
]
391
457
# Add to the aggregate segments cord.
392
- segments .append (mutable_segment_data )
458
+ segments .append (AlignedData ( mutable_segment_data ) )
393
459
394
460
if extract_delegate_segments :
395
461
_extract_delegate_segments (program , segments )
462
+ if named_data is not None :
463
+ _extract_named_data (program , segments , named_data .buffers , named_data .pte_data )
396
464
397
465
# Append all segments into a single Cord, adding any necessary padding to ensure that
398
466
# each segment begins at the required alignment.
399
467
# Update program.segments with the offsets to each segment.
400
468
segments_data = Cord ()
401
- for data in segments :
469
+ for segment in segments :
402
470
prev_end = (
403
471
(program .segments [- 1 ].offset + program .segments [- 1 ].size )
404
472
if program .segments
405
473
else 0
406
474
)
475
+ alignment = math .lcm (segment_alignment , segment .alignment )
407
476
program .segments .append (
408
477
DataSegment (
409
- offset = aligned_size (prev_end , segment_alignment ), size = len (data )
478
+ offset = aligned_size (prev_end , alignment ), size = len (segment . data )
410
479
)
411
480
)
412
481
# Add to aggregate segments cord with padding.
413
- padding_length = padding_required (len (segments_data ), segment_alignment )
482
+ padding_length = padding_required (len (segments_data ), alignment )
414
483
if padding_length > 0 :
415
484
segments_data .append (b"\x00 " * padding_length )
416
- segments_data .append (data )
485
+ segments_data .append (segment . data )
417
486
418
487
# Convert to a standard flatbuffer binary.
419
488
result : _FlatbufferResult = _program_json_to_flatbuffer (
0 commit comments