Skip to content

Commit ea0204b

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
MethodMeta
Summary: There is a growing class of usecases that want to be able to inspect meta data about methods without paying the full init cost. These classes provide a safe and cheap way to view this information Reviewed By: dbort Differential Revision: D48039273 fbshipit-source-id: 261e7ac7d7ea6b1cc63f52d54815fd30711059b8
1 parent d7d91a2 commit ea0204b

File tree

7 files changed

+543
-0
lines changed

7 files changed

+543
-0
lines changed

runtime/executor/method_meta.cpp

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/error.h>
10+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12+
#include <executorch/runtime/core/result.h>
13+
#include <executorch/runtime/core/span.h>
14+
#include <executorch/runtime/core/tag.h>
15+
#include <executorch/runtime/executor/method_meta.h>
16+
#include <executorch/schema/program_generated.h>
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
namespace {
22+
Result<Tag> get_tag(
23+
flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::EValue>>::
24+
return_type serialization_value,
25+
size_t index) {
26+
switch (serialization_value->val_type()) {
27+
case executorch_flatbuffer::KernelTypes::Int: {
28+
return Tag::Int;
29+
} break;
30+
case executorch_flatbuffer::KernelTypes::Double: {
31+
return Tag::Double;
32+
} break;
33+
case executorch_flatbuffer::KernelTypes::Bool: {
34+
return Tag::Bool;
35+
} break;
36+
case executorch_flatbuffer::KernelTypes::String: {
37+
return Tag::String;
38+
} break;
39+
case executorch_flatbuffer::KernelTypes::Tensor: {
40+
return Tag::Tensor;
41+
} break;
42+
default:
43+
ET_LOG(
44+
Error,
45+
"Invalid tag: %zu input: %zu",
46+
(size_t)serialization_value->val_type(),
47+
index);
48+
return Error::Internal;
49+
}
50+
}
51+
52+
size_t calculate_nbytes(
53+
Span<const int32_t> sizes,
54+
exec_aten::ScalarType scalar_type) {
55+
ssize_t n = 1;
56+
for (ssize_t i = 0; i < sizes.size(); i++) {
57+
n *= sizes[i];
58+
}
59+
return n * sizeof_scalar_type(scalar_type);
60+
}
61+
62+
} // namespace
63+
64+
TensorInfo::TensorInfo(
65+
Span<const int32_t> sizes,
66+
Span<const uint8_t> dim_order,
67+
exec_aten::ScalarType scalar_type) noexcept
68+
: sizes_(sizes),
69+
dim_order_(dim_order),
70+
scalar_type_(scalar_type),
71+
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
72+
73+
Span<const int32_t> TensorInfo::sizes() const noexcept {
74+
return sizes_;
75+
}
76+
77+
Span<const uint8_t> TensorInfo::dim_order() const noexcept {
78+
return dim_order_;
79+
}
80+
81+
exec_aten::ScalarType TensorInfo::scalar_type() const noexcept {
82+
return scalar_type_;
83+
}
84+
85+
size_t TensorInfo::nbytes() const noexcept {
86+
return nbytes_;
87+
}
88+
89+
MethodMeta::MethodMeta(
90+
const executorch_flatbuffer::ExecutionPlan* s_plan) noexcept
91+
: s_plan_(s_plan) {}
92+
93+
const char* MethodMeta::name() const noexcept {
94+
return s_plan_->name()->c_str();
95+
}
96+
97+
size_t MethodMeta::num_inputs() const noexcept {
98+
return s_plan_->inputs()->size();
99+
}
100+
101+
Result<Tag> MethodMeta::input_tag(size_t index) const noexcept {
102+
auto num_inputs = this->num_inputs();
103+
ET_CHECK_OR_RETURN_ERROR(
104+
index >= 0 && index < num_inputs,
105+
InvalidArgument,
106+
"index %zu out of range. num_inputs: %zu",
107+
index,
108+
num_inputs);
109+
auto input_index = s_plan_->inputs()->Get(index);
110+
auto serialization_value = s_plan_->values()->Get(input_index);
111+
return get_tag(serialization_value, index);
112+
}
113+
114+
Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const noexcept {
115+
auto tag = this->input_tag(index);
116+
if (!tag.ok()) {
117+
return tag.error();
118+
}
119+
ET_CHECK_OR_RETURN_ERROR(
120+
tag.get() == Tag::Tensor,
121+
InvalidArgument,
122+
"Tag: %zu input: %zu is not Tensor",
123+
(size_t)tag.get(),
124+
index);
125+
auto input_index = s_plan_->inputs()->Get(index);
126+
auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
127+
return TensorInfo(
128+
Span<const int32_t>(
129+
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
130+
Span<const uint8_t>(
131+
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
132+
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()));
133+
}
134+
135+
size_t MethodMeta::num_outputs() const noexcept {
136+
return s_plan_->outputs()->size();
137+
}
138+
139+
Result<Tag> MethodMeta::output_tag(size_t index) const noexcept {
140+
auto num_outputs = this->num_outputs();
141+
ET_CHECK_OR_RETURN_ERROR(
142+
index >= 0 && index < num_outputs,
143+
InvalidArgument,
144+
"index %zu out of range. num_outputs: %zu",
145+
index,
146+
num_outputs);
147+
auto input_index = s_plan_->outputs()->Get(index);
148+
auto serialization_value = s_plan_->values()->Get(input_index);
149+
return get_tag(serialization_value, index);
150+
}
151+
152+
Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const noexcept {
153+
auto tag = this->output_tag(index);
154+
if (!tag.ok()) {
155+
return tag.error();
156+
}
157+
ET_CHECK_OR_RETURN_ERROR(
158+
tag.get() == Tag::Tensor,
159+
InvalidArgument,
160+
"Tag: %zu output: %zu is not Tensor",
161+
(size_t)tag.get(),
162+
index);
163+
auto input_index = s_plan_->outputs()->Get(index);
164+
auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
165+
return TensorInfo(
166+
Span<const int32_t>(
167+
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
168+
Span<const uint8_t>(
169+
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
170+
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()));
171+
}
172+
173+
size_t MethodMeta::num_non_const_buffers() const noexcept {
174+
return s_plan_->non_const_buffer_sizes()->size();
175+
}
176+
177+
Result<int64_t> MethodMeta::non_const_buffer_size(size_t index) const noexcept {
178+
auto num_buffers = this->num_non_const_buffers();
179+
ET_CHECK_OR_RETURN_ERROR(
180+
index >= 0 && index < num_buffers,
181+
InvalidArgument,
182+
"index %zu out of range. num_buffers: %zu",
183+
index,
184+
num_buffers);
185+
return s_plan_->non_const_buffer_sizes()->Get(index);
186+
}
187+
188+
} // namespace executor
189+
} // namespace torch

runtime/executor/method_meta.h

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/result.h>
13+
#include <executorch/runtime/core/span.h>
14+
#include <executorch/runtime/core/tag.h>
15+
16+
// Forward declare flatbuffer types. This is a public header and must not
17+
// include the generated flatbuffer header.
18+
namespace executorch_flatbuffer {
19+
struct ExecutionPlan;
20+
} // namespace executorch_flatbuffer
21+
22+
namespace torch {
23+
namespace executor {
24+
25+
/**
26+
* Metadata about a specific tensor of an Executorch Program.
27+
*
28+
* The program used to create the MethodMeta object that created this
29+
* TensorInfo must outlive this TensorInfo.
30+
*/
31+
class TensorInfo final {
32+
public:
33+
TensorInfo() = delete;
34+
TensorInfo(const TensorInfo&) = default;
35+
TensorInfo(TensorInfo&&) = default;
36+
TensorInfo& operator=(const TensorInfo&) = default;
37+
TensorInfo& operator=(TensorInfo&& other) = default;
38+
~TensorInfo() = default;
39+
40+
/**
41+
* Returns the sizes of the tensor.
42+
*/
43+
Span<const int32_t> sizes() const;
44+
45+
/**
46+
* Returns the dim order of the tensor.
47+
*/
48+
Span<const uint8_t> dim_order() const;
49+
50+
/**
51+
* Returns the scalar type of the input/output.
52+
*/
53+
exec_aten::ScalarType scalar_type() const;
54+
55+
/**
56+
* Returns the size of the tensor in bytes.
57+
*/
58+
size_t nbytes() const;
59+
60+
private:
61+
// Let MethodMeta create TensorInfo.
62+
friend class MethodMeta;
63+
64+
TensorInfo(
65+
Span<const int32_t> sizes,
66+
Span<const uint8_t> dim_order,
67+
exec_aten::ScalarType scalar_type);
68+
69+
/**
70+
* The sizes of the tensor.
71+
*
72+
* NOTE: References data from the Program, so the Program must outlive the
73+
* TensorInfo.
74+
*/
75+
Span<const int32_t> sizes_;
76+
77+
/**
78+
* The dim order of the tensor.
79+
*
80+
* NOTE: References data from the Program, so the Program must outlive the
81+
* TensorInfo.
82+
*/
83+
Span<const uint8_t> dim_order_;
84+
85+
/// The scalar type of the tensor.
86+
exec_aten::ScalarType scalar_type_;
87+
88+
/// The size in bytes of the tensor.
89+
size_t nbytes_;
90+
};
91+
92+
/**
93+
* Describes a a method in an Executorch program.
94+
*
95+
* The program used to create a MethodMeta object must outlive the MethodMeta.
96+
* It is separate from Method so that this information can be accessed without
97+
* paying the initialization cost of loading the full Method.
98+
*/
99+
class MethodMeta final {
100+
public:
101+
MethodMeta() = delete;
102+
MethodMeta(const MethodMeta&) = default;
103+
MethodMeta(MethodMeta&&) = default;
104+
MethodMeta& operator=(const MethodMeta&) = default;
105+
MethodMeta& operator=(MethodMeta&& other) = default;
106+
~MethodMeta() = default;
107+
108+
/**
109+
* Get the name of this method.
110+
*
111+
* @returns The method name.
112+
*/
113+
const char* name() const;
114+
115+
/**
116+
* Get the number of inputs to this method.
117+
*
118+
* @returns The number of inputs.
119+
*/
120+
size_t num_inputs() const;
121+
122+
/**
123+
* Get the tag of the specified input.
124+
*
125+
* @param[in] index The index of the input to look up.
126+
* @returns The tag of input, can only be [Tensor, Int, Bool, Double, String].
127+
*/
128+
Result<Tag> input_tag(size_t index) const;
129+
130+
/**
131+
* Get metadata about the specified input.
132+
*
133+
* @param[in] index The index of the input to look up.
134+
* @returns The metadata on success, or an error on failure. Only valid for
135+
* tag::Tensor
136+
*/
137+
Result<TensorInfo> input_tensor_meta(size_t index) const;
138+
139+
/**
140+
* Get the number of outputs to this method.
141+
*
142+
* @returns The number of outputs.
143+
*/
144+
size_t num_outputs() const;
145+
146+
/**
147+
* Get the tag of the specified output.
148+
*
149+
* @param[in] index The index of the output to look up.
150+
* @returns The tag of output, can only be [Tensor, Int, Bool, Double,
151+
* String].
152+
*/
153+
Result<Tag> output_tag(size_t index) const;
154+
155+
/**
156+
* Get metadata about the specified output.
157+
*
158+
* @param[in] index The index of the output to look up.
159+
* @returns The metadata on success, or an error on failure. Only valid for
160+
* tag::Tensor
161+
*/
162+
Result<TensorInfo> output_tensor_meta(size_t index) const;
163+
164+
/**
165+
* Get the number of non-constant buffers this method requires.
166+
*
167+
* @returns The number of non-constant buffers.
168+
*/
169+
size_t num_non_const_buffers() const;
170+
171+
/**
172+
* Get the size in bytes of the specified non-constant buffer.
173+
*
174+
* @param[in] index The index of the buffer to look up.
175+
* @returns The size in bytes on success, or an error on failure.
176+
*/
177+
Result<int64_t> non_const_buffer_size(size_t index) const;
178+
179+
private:
180+
// Let Program create MethodMeta.
181+
friend class Program;
182+
183+
explicit MethodMeta(const executorch_flatbuffer::ExecutionPlan* s_plan);
184+
185+
/// Source of truth for method information
186+
const executorch_flatbuffer::ExecutionPlan* s_plan_;
187+
};
188+
189+
} // namespace executor
190+
} // namespace torch

runtime/executor/program.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ Result<Method> Program::load_method(
206206
return Error::InvalidArgument;
207207
}
208208

209+
Result<MethodMeta> Program::method_meta(const char* method_name) const {
210+
EXECUTORCH_SCOPE_PROF("Program::method_meta");
211+
auto execution_plans = internal_program_->execution_plan();
212+
for (size_t i = 0; i < execution_plans->size(); i++) {
213+
auto serialization_plan = execution_plans->GetMutableObject(i);
214+
if (std::strcmp(serialization_plan->name()->c_str(), method_name) == 0) {
215+
return MethodMeta(serialization_plan);
216+
}
217+
}
218+
return Error::InvalidArgument;
219+
}
220+
209221
const void* Program::get_constant_buffer_data(size_t buffer_idx) const {
210222
ET_CHECK(is_valid());
211223
auto internal_program =

0 commit comments

Comments
 (0)