|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <executorch/runtime/core/error.h> |
| 10 | +#include <executorch/runtime/core/exec_aten/exec_aten.h> |
| 11 | +#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
| 12 | +#include <executorch/runtime/core/result.h> |
| 13 | +#include <executorch/runtime/core/span.h> |
| 14 | +#include <executorch/runtime/core/tag.h> |
| 15 | +#include <executorch/runtime/executor/method_meta.h> |
| 16 | +#include <executorch/schema/program_generated.h> |
| 17 | + |
| 18 | +namespace torch { |
| 19 | +namespace executor { |
| 20 | + |
| 21 | +namespace { |
| 22 | +Result<Tag> get_tag( |
| 23 | + flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::EValue>>:: |
| 24 | + return_type serialization_value, |
| 25 | + size_t index) { |
| 26 | + switch (serialization_value->val_type()) { |
| 27 | + case executorch_flatbuffer::KernelTypes::Int: { |
| 28 | + return Tag::Int; |
| 29 | + } break; |
| 30 | + case executorch_flatbuffer::KernelTypes::Double: { |
| 31 | + return Tag::Double; |
| 32 | + } break; |
| 33 | + case executorch_flatbuffer::KernelTypes::Bool: { |
| 34 | + return Tag::Bool; |
| 35 | + } break; |
| 36 | + case executorch_flatbuffer::KernelTypes::String: { |
| 37 | + return Tag::String; |
| 38 | + } break; |
| 39 | + case executorch_flatbuffer::KernelTypes::Tensor: { |
| 40 | + return Tag::Tensor; |
| 41 | + } break; |
| 42 | + default: |
| 43 | + ET_LOG( |
| 44 | + Error, |
| 45 | + "Invalid tag: %zu input: %zu", |
| 46 | + (size_t)serialization_value->val_type(), |
| 47 | + index); |
| 48 | + return Error::Internal; |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +size_t calculate_nbytes( |
| 53 | + Span<const int32_t> sizes, |
| 54 | + exec_aten::ScalarType scalar_type) { |
| 55 | + ssize_t n = 1; |
| 56 | + for (ssize_t i = 0; i < sizes.size(); i++) { |
| 57 | + n *= sizes[i]; |
| 58 | + } |
| 59 | + return n * sizeof_scalar_type(scalar_type); |
| 60 | +} |
| 61 | + |
| 62 | +} // namespace |
| 63 | + |
| 64 | +TensorInfo::TensorInfo( |
| 65 | + Span<const int32_t> sizes, |
| 66 | + Span<const uint8_t> dim_order, |
| 67 | + exec_aten::ScalarType scalar_type) noexcept |
| 68 | + : sizes_(sizes), |
| 69 | + dim_order_(dim_order), |
| 70 | + scalar_type_(scalar_type), |
| 71 | + nbytes_(calculate_nbytes(sizes_, scalar_type_)) {} |
| 72 | + |
| 73 | +Span<const int32_t> TensorInfo::sizes() const noexcept { |
| 74 | + return sizes_; |
| 75 | +} |
| 76 | + |
| 77 | +Span<const uint8_t> TensorInfo::dim_order() const noexcept { |
| 78 | + return dim_order_; |
| 79 | +} |
| 80 | + |
| 81 | +exec_aten::ScalarType TensorInfo::scalar_type() const noexcept { |
| 82 | + return scalar_type_; |
| 83 | +} |
| 84 | + |
| 85 | +size_t TensorInfo::nbytes() const noexcept { |
| 86 | + return nbytes_; |
| 87 | +} |
| 88 | + |
| 89 | +MethodMeta::MethodMeta( |
| 90 | + const executorch_flatbuffer::ExecutionPlan* s_plan) noexcept |
| 91 | + : s_plan_(s_plan) {} |
| 92 | + |
| 93 | +const char* MethodMeta::name() const noexcept { |
| 94 | + return s_plan_->name()->c_str(); |
| 95 | +} |
| 96 | + |
| 97 | +size_t MethodMeta::num_inputs() const noexcept { |
| 98 | + return s_plan_->inputs()->size(); |
| 99 | +} |
| 100 | + |
| 101 | +Result<Tag> MethodMeta::input_tag(size_t index) const noexcept { |
| 102 | + auto num_inputs = this->num_inputs(); |
| 103 | + ET_CHECK_OR_RETURN_ERROR( |
| 104 | + index >= 0 && index < num_inputs, |
| 105 | + InvalidArgument, |
| 106 | + "index %zu out of range. num_inputs: %zu", |
| 107 | + index, |
| 108 | + num_inputs); |
| 109 | + auto input_index = s_plan_->inputs()->Get(index); |
| 110 | + auto serialization_value = s_plan_->values()->Get(input_index); |
| 111 | + return get_tag(serialization_value, index); |
| 112 | +} |
| 113 | + |
| 114 | +Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const noexcept { |
| 115 | + auto tag = this->input_tag(index); |
| 116 | + if (!tag.ok()) { |
| 117 | + return tag.error(); |
| 118 | + } |
| 119 | + ET_CHECK_OR_RETURN_ERROR( |
| 120 | + tag.get() == Tag::Tensor, |
| 121 | + InvalidArgument, |
| 122 | + "Tag: %zu input: %zu is not Tensor", |
| 123 | + (size_t)tag.get(), |
| 124 | + index); |
| 125 | + auto input_index = s_plan_->inputs()->Get(index); |
| 126 | + auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor(); |
| 127 | + return TensorInfo( |
| 128 | + Span<const int32_t>( |
| 129 | + tensor_value->sizes()->data(), tensor_value->sizes()->size()), |
| 130 | + Span<const uint8_t>( |
| 131 | + tensor_value->dim_order()->data(), tensor_value->dim_order()->size()), |
| 132 | + static_cast<exec_aten::ScalarType>(tensor_value->scalar_type())); |
| 133 | +} |
| 134 | + |
| 135 | +size_t MethodMeta::num_outputs() const noexcept { |
| 136 | + return s_plan_->outputs()->size(); |
| 137 | +} |
| 138 | + |
| 139 | +Result<Tag> MethodMeta::output_tag(size_t index) const noexcept { |
| 140 | + auto num_outputs = this->num_outputs(); |
| 141 | + ET_CHECK_OR_RETURN_ERROR( |
| 142 | + index >= 0 && index < num_outputs, |
| 143 | + InvalidArgument, |
| 144 | + "index %zu out of range. num_outputs: %zu", |
| 145 | + index, |
| 146 | + num_outputs); |
| 147 | + auto input_index = s_plan_->outputs()->Get(index); |
| 148 | + auto serialization_value = s_plan_->values()->Get(input_index); |
| 149 | + return get_tag(serialization_value, index); |
| 150 | +} |
| 151 | + |
| 152 | +Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const noexcept { |
| 153 | + auto tag = this->output_tag(index); |
| 154 | + if (!tag.ok()) { |
| 155 | + return tag.error(); |
| 156 | + } |
| 157 | + ET_CHECK_OR_RETURN_ERROR( |
| 158 | + tag.get() == Tag::Tensor, |
| 159 | + InvalidArgument, |
| 160 | + "Tag: %zu output: %zu is not Tensor", |
| 161 | + (size_t)tag.get(), |
| 162 | + index); |
| 163 | + auto input_index = s_plan_->outputs()->Get(index); |
| 164 | + auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor(); |
| 165 | + return TensorInfo( |
| 166 | + Span<const int32_t>( |
| 167 | + tensor_value->sizes()->data(), tensor_value->sizes()->size()), |
| 168 | + Span<const uint8_t>( |
| 169 | + tensor_value->dim_order()->data(), tensor_value->dim_order()->size()), |
| 170 | + static_cast<exec_aten::ScalarType>(tensor_value->scalar_type())); |
| 171 | +} |
| 172 | + |
| 173 | +size_t MethodMeta::num_non_const_buffers() const noexcept { |
| 174 | + return s_plan_->non_const_buffer_sizes()->size(); |
| 175 | +} |
| 176 | + |
| 177 | +Result<int64_t> MethodMeta::non_const_buffer_size(size_t index) const noexcept { |
| 178 | + auto num_buffers = this->num_non_const_buffers(); |
| 179 | + ET_CHECK_OR_RETURN_ERROR( |
| 180 | + index >= 0 && index < num_buffers, |
| 181 | + InvalidArgument, |
| 182 | + "index %zu out of range. num_buffers: %zu", |
| 183 | + index, |
| 184 | + num_buffers); |
| 185 | + return s_plan_->non_const_buffer_sizes()->Get(index); |
| 186 | +} |
| 187 | + |
| 188 | +} // namespace executor |
| 189 | +} // namespace torch |
0 commit comments