@@ -19,6 +19,8 @@ namespace executorch {
19
19
namespace runtime {
20
20
namespace deserialization {
21
21
22
+ using executorch::aten::ScalarType;
23
+ using executorch::runtime::TensorLayout;
22
24
// Provides access to private Program methods.
23
25
class TensorParser final {
24
26
public:
@@ -113,7 +115,8 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
113
115
const executorch_flatbuffer::Tensor* s_tensor,
114
116
const Program* program,
115
117
size_t nbytes,
116
- HierarchicalAllocator* allocator) {
118
+ HierarchicalAllocator* allocator,
119
+ const NamedDataMap* named_data_map) {
117
120
auto data_buffer_idx = s_tensor->data_buffer_idx ();
118
121
const executorch_flatbuffer::AllocationDetails* allocation_info =
119
122
s_tensor->allocation_info ();
@@ -131,9 +134,103 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
131
134
return err;
132
135
}
133
136
return planned_ptr;
137
+ }
138
+
139
+ // External tensors.
140
+ else if (
141
+ s_tensor->extra_tensor_info () != nullptr &&
142
+ s_tensor->extra_tensor_info ()->location () ==
143
+ executorch_flatbuffer::TensorDataLocation::EXTERNAL) {
144
+ // Check that fqn is not null.
145
+ ET_CHECK_OR_RETURN_ERROR (
146
+ s_tensor->extra_tensor_info ()->fully_qualified_name () != nullptr ,
147
+ InvalidExternalData,
148
+ " Fully qualified name of external tensor is null" );
149
+ // Look up tensor in named data map.
150
+ Result<const TensorLayout> tensor_layout_res = named_data_map->get_metadata (
151
+ s_tensor->extra_tensor_info ()->fully_qualified_name ()->c_str ());
152
+ if (!tensor_layout_res.ok ()) {
153
+ return tensor_layout_res.error ();
154
+ }
155
+ const TensorLayout& tensor_layout = tensor_layout_res.get ();
156
+
157
+ // Compatibility checking.
158
+ ET_CHECK_OR_RETURN_ERROR (
159
+ static_cast <ScalarType>(s_tensor->scalar_type ()) ==
160
+ tensor_layout.scalar_type (),
161
+ InvalidExternalData,
162
+ " Scalar type mismatch. Expected %hhd, got %hhd." ,
163
+ static_cast <int8_t >(s_tensor->scalar_type ()),
164
+ static_cast <int8_t >(tensor_layout.scalar_type ()));
165
+ ET_CHECK_OR_RETURN_ERROR (
166
+ nbytes == tensor_layout.nbytes (),
167
+ InvalidExternalData,
168
+ " Nbytes mismatch. Expected %zu, got %zu." ,
169
+ nbytes,
170
+ tensor_layout.nbytes ());
171
+ int dim = s_tensor->sizes ()->size ();
172
+ ET_CHECK_OR_RETURN_ERROR (
173
+ dim == tensor_layout.sizes ().size (),
174
+ InvalidExternalData,
175
+ " Dim mismatch. Expected %d, got %zu." ,
176
+ dim,
177
+ tensor_layout.sizes ().size ());
178
+ for (int i = 0 ; i < dim; i++) {
179
+ ET_CHECK_OR_RETURN_ERROR (
180
+ s_tensor->sizes ()->Get (i) == tensor_layout.sizes ()[i],
181
+ InvalidExternalData,
182
+ " Sizes mismatch. Expected %d, got %d for size at index %d." ,
183
+ s_tensor->sizes ()->Get (i),
184
+ tensor_layout.sizes ()[i],
185
+ i);
186
+ ET_CHECK_OR_RETURN_ERROR (
187
+ s_tensor->dim_order ()->Get (i) == tensor_layout.dim_order ()[i],
188
+ InvalidExternalData,
189
+ " Dim order mismatch. Expected %d, got %d for dim at index %d." ,
190
+ s_tensor->dim_order ()->Get (i),
191
+ tensor_layout.dim_order ()[i],
192
+ i);
193
+ }
194
+
195
+ // Constant value.
196
+ if (allocation_info == nullptr ) {
197
+ Result<FreeableBuffer> data_res = named_data_map->get_data (
198
+ s_tensor->extra_tensor_info ()->fully_qualified_name ()->c_str ());
199
+ if (!data_res.ok ()) {
200
+ return data_res.error ();
201
+ }
202
+ // The const_cast is 'ok' here because program and runtime should
203
+ // guarantee that this data is never modified. Temporary until runtime
204
+ // takes ownership of FreeableBuffers in TODO(T214294528).
205
+ return const_cast <void *>(data_res.get ().data ());
206
+ }
207
+
208
+ // Mutable value.
209
+ else {
210
+ // Call load_into.
211
+ auto planned_ptr = getMemPlannedPtr (allocation_info, nbytes, allocator);
212
+ if (!planned_ptr.ok ()) {
213
+ return planned_ptr.error ();
214
+ }
215
+ auto size = named_data_map->load_data_into (
216
+ s_tensor->extra_tensor_info ()->fully_qualified_name ()->c_str (),
217
+ planned_ptr.get (),
218
+ nbytes);
219
+ if (size.error () != Error::Ok) {
220
+ return size.error ();
221
+ }
222
+ ET_CHECK_OR_RETURN_ERROR (
223
+ size.get () == nbytes,
224
+ InvalidExternalData,
225
+ " Expected to load %zu bytes, actually loaded %u bytes" ,
226
+ nbytes,
227
+ static_cast <unsigned int >(size.get ()));
228
+ return planned_ptr;
229
+ }
230
+ }
134
231
135
- // Constant
136
- } else if (data_buffer_idx > 0 && allocation_info == nullptr ) {
232
+ // Constant, stored in PTE file.
233
+ else if (data_buffer_idx > 0 && allocation_info == nullptr ) {
137
234
auto const_data =
138
235
program->get_constant_buffer_data (data_buffer_idx, nbytes);
139
236
if (!const_data.ok ()) {
0 commit comments