Skip to content

Commit 29b3792

Browse files
SS-JIA and facebook-github-bot
authored and committed
Extend support for scalars and scalar lists in Value class (#2271)
Summary:

## Context

This changeset enables serialization and execution of operators with arbitrary function signatures. Previously, only operators with one very specific schema (2 inputs, 1 output) were supported. Support for arbitrary signatures is achieved by extending the `Value` class (essentially a tagged union) to support all necessary types; every object needed to execute an operator is now serialized and deserialized as a tagged union.

This changeset also refactors `VulkanBackend.cpp` by introducing `GraphBuilder`, which makes constructing a `ComputeGraph` from a serialized flatbuffer much clearer.

Differential Revision: D54561567
1 parent bcba739 commit 29b3792

14 files changed

+677
-329
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

+167-155
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323
#include <cstdlib> /* strtol */
2424
#include <memory>
2525
#include <type_traits>
26+
#include <vector>
2627

2728
namespace torch {
2829
namespace executor {
2930
namespace vulkan {
3031
namespace {
3132

33+
using namespace at::native::vulkan;
34+
3235
// Flatbuffer types
3336
using VkGraphPtr = const vkgraph::VkGraph*;
3437
using OpCallPtr = const vkgraph::OperatorCall*;
@@ -51,102 +54,193 @@ const uint8_t* getConstantDataPtr(
5154
return constant_data + constant_bytes->offset();
5255
}
5356

54-
using namespace at::native::vulkan;
57+
api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
58+
switch (vk_datatype) {
59+
case (vkgraph::VkDataType::fp32): {
60+
return api::kFloat;
61+
}
62+
}
63+
}
64+
65+
GraphConfig generate_config() {
66+
const uint32_t submit_frequency = UINT32_MAX;
67+
68+
const api::CommandPoolConfig cmd_config{
69+
4u, // cmdPoolInitialSize
70+
2u, // cmdPoolBatchSize
71+
};
72+
73+
const api::DescriptorPoolConfig descriptor_pool_config{
74+
1024u, // descriptorPoolMaxSets
75+
1024u, // descriptorUniformBufferCount
76+
1024u, // descriptorStorageBufferCount
77+
1024u, // descriptorCombinedSamplerCount
78+
1024u, // descriptorStorageImageCount
79+
32u, // descriptorPileSizes
80+
};
81+
82+
const api::QueryPoolConfig query_pool_config{};
83+
84+
const api::ContextConfig context_config{
85+
submit_frequency, // cmdSubmitFrequency
86+
cmd_config, // cmdPoolConfig
87+
descriptor_pool_config, // descriptorPoolConfig
88+
query_pool_config, // queryPoolConfig
89+
};
90+
91+
const GraphConfig graph_config{
92+
context_config,
93+
};
94+
95+
return graph_config;
96+
}
97+
98+
class GraphBuilder {
99+
ComputeGraph* compute_graph_;
100+
VkGraphPtr flatbuffer_;
101+
const uint8_t* constant_data_;
102+
103+
std::unordered_map<uint32_t, ValueRef> ref_mapping_;
55104

56-
class VulkanBackend final : public PyTorchBackendInterface {
57105
public:
58-
~VulkanBackend() override = default;
106+
explicit GraphBuilder(
107+
ComputeGraph* compute_graph,
108+
VkGraphPtr flatbuffer,
109+
const uint8_t* constant_data)
110+
: compute_graph_(compute_graph),
111+
flatbuffer_(flatbuffer),
112+
constant_data_(constant_data),
113+
ref_mapping_() {}
114+
115+
bool fb_id_exists(const uint32_t fb_id) {
116+
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
117+
ref_mapping_.find(fb_id);
59118

60-
bool is_available() const override {
61-
return true;
119+
return found_ref != ref_mapping_.end();
62120
}
63121

64-
api::ScalarType get_scalar_type(
65-
const vkgraph::VkDataType& vk_datatype) const {
66-
switch (vk_datatype) {
67-
case (vkgraph::VkDataType::fp32): {
68-
return api::kFloat;
69-
}
70-
}
122+
ValueRef get_fb_id_valueref(const uint32_t fb_id) {
123+
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
124+
ref_mapping_.find(fb_id);
125+
126+
ET_CHECK_MSG(
127+
found_ref != ref_mapping_.end(),
128+
"Trying to extract a value that hasn't yet been added to the graph.");
129+
130+
return found_ref->second;
71131
}
72132

73-
ValueRef get_value_ref(
74-
const uint32_t value_id,
75-
VkGraphPtr flatbuffer_graph,
76-
ComputeGraph* compute_graph,
77-
std::unordered_map<uint32_t, ValueRef>& ref_mapping,
78-
VkValuesVector value_mapping,
79-
const uint8_t* constant_data) const {
80-
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
81-
ref_mapping.find(value_id);
133+
void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
134+
const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
135+
136+
UIntVector dims_fb = tensor_fb->dims();
137+
const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
82138

83-
if (found_ref != ref_mapping.end()) {
84-
return found_ref->second;
139+
ValueRef ref;
140+
if (tensor_fb->constant_id() >= 0) {
141+
const uint8_t* tensor_data = getConstantDataPtr(
142+
flatbuffer_, tensor_fb->constant_id(), constant_data_);
143+
144+
ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
145+
} else {
146+
ref = compute_graph_->add_tensor(
147+
dims_vector, dtype, tensor_fb->mem_obj_id());
85148
}
86149

87-
VkValuePtr vk_value = value_mapping->Get(value_id);
88-
VkTensorPtr vk_tensor = vk_value->value();
150+
ref_mapping_[fb_id] = ref;
151+
}
152+
153+
template <typename T>
154+
typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
155+
add_scalar_to_graph(const uint32_t fb_id, T value) {
156+
ValueRef ref = compute_graph_->add_scalar(value);
157+
ref_mapping_[fb_id] = ref;
158+
}
159+
160+
void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
161+
const auto fb_str = value->value_as_String()->string_val();
162+
std::string string(fb_str->cbegin(), fb_str->cend());
163+
ValueRef ref = compute_graph_->add_string(std::move(string));
164+
ref_mapping_[fb_id] = ref;
165+
}
89166

167+
void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
90168
ET_CHECK_MSG(
91-
vk_tensor->constant_id() >= 0,
92-
"Only constant buffers are supported when adding tensors to compute graph (indicated by constant_id < 0), but got constant_id of %d",
93-
vk_tensor->constant_id());
169+
!fb_id_exists(fb_id),
170+
"Trying to add a value that has already been added to the graph.");
171+
172+
switch (value->value_type()) {
173+
case vkgraph::GraphTypes::Int:
174+
add_scalar_to_graph(fb_id, value->value_as_Int()->int_val());
175+
break;
176+
case vkgraph::GraphTypes::Double:
177+
add_scalar_to_graph(fb_id, value->value_as_Double()->double_val());
178+
break;
179+
case vkgraph::GraphTypes::Bool:
180+
add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val());
181+
break;
182+
case vkgraph::GraphTypes::VkTensor:
183+
add_tensor_to_graph(fb_id, value->value_as_VkTensor());
184+
break;
185+
case vkgraph::GraphTypes::String:
186+
add_string_to_graph(fb_id, value);
187+
break;
188+
default:
189+
ET_CHECK_MSG(false, "Unsupported value type.");
190+
}
191+
}
94192

95-
const api::ScalarType& tensor_dtype =
96-
get_scalar_type(vk_tensor->datatype());
193+
void build_graph() {
194+
// First, add all values to the graph
195+
for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
196+
VkValuePtr value = flatbuffer_->values()->Get(fb_id);
197+
add_value_to_graph(fb_id, value);
198+
}
97199

98-
UIntVector tensor_dims_fb = vk_tensor->dims();
99-
const std::vector<int64_t> tensor_dims_vector(
100-
tensor_dims_fb->cbegin(), tensor_dims_fb->cend());
200+
// Parse the inputs
201+
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
202+
const ValueRef ref = get_fb_id_valueref(fb_id);
203+
compute_graph_->set_input_tensor(ref);
204+
}
101205

102-
const uint8_t* tensor_data = getConstantDataPtr(
103-
flatbuffer_graph, vk_tensor->constant_id(), constant_data);
206+
// Parse the operators
207+
for (OpCallPtr op_call : *(flatbuffer_->chain())) {
208+
std::string op_name = op_call->name()->str();
209+
ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());
104210

105-
const ValueRef value_ref = compute_graph->add_tensorref(
106-
tensor_dims_vector, tensor_dtype, tensor_data);
211+
const std::vector<int> arg_fb_ids(
212+
op_call->args()->cbegin(), op_call->args()->cend());
107213

108-
ref_mapping[value_id] = value_ref;
214+
std::vector<ValueRef> args;
215+
for (const int arg_fb_id : arg_fb_ids) {
216+
args.push_back(get_fb_id_valueref(arg_fb_id));
217+
}
109218

110-
return value_ref;
219+
auto vkFn = getOpsFn(op_name);
220+
vkFn(*compute_graph_, args);
221+
}
222+
223+
// Parse the outputs
224+
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
225+
const ValueRef ref = get_fb_id_valueref(fb_id);
226+
compute_graph_->set_output_tensor(ref);
227+
}
111228
}
229+
};
112230

113-
GraphConfig generate_config() const {
114-
const uint32_t submit_frequency = UINT32_MAX;
115-
116-
const api::CommandPoolConfig cmd_config{
117-
4u, // cmdPoolInitialSize
118-
2u, // cmdPoolBatchSize
119-
};
120-
121-
const api::DescriptorPoolConfig descriptor_pool_config{
122-
1024u, // descriptorPoolMaxSets
123-
1024u, // descriptorUniformBufferCount
124-
1024u, // descriptorStorageBufferCount
125-
1024u, // descriptorCombinedSamplerCount
126-
1024u, // descriptorStorageImageCount
127-
32u, // descriptorPileSizes
128-
};
129-
130-
const api::QueryPoolConfig query_pool_config{};
131-
132-
const api::ContextConfig context_config{
133-
submit_frequency, // cmdSubmitFrequency
134-
cmd_config, // cmdPoolConfig
135-
descriptor_pool_config, // descriptorPoolConfig
136-
query_pool_config, // queryPoolConfig
137-
};
138-
139-
const GraphConfig graph_config{
140-
context_config,
141-
};
142-
143-
return graph_config;
231+
class VulkanBackend final : public PyTorchBackendInterface {
232+
public:
233+
~VulkanBackend() override = default;
234+
235+
bool is_available() const override {
236+
return true;
144237
}
145238

146239
__ET_NODISCARD Error
147240
compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const {
148241
Result<VulkanDelegateHeader> header =
149242
VulkanDelegateHeader::Parse(buffer_pointer);
243+
150244
const uint8_t* flatbuffer_data = nullptr;
151245
const uint8_t* constant_data = nullptr;
152246

@@ -169,92 +263,10 @@ class VulkanBackend final : public PyTorchBackendInterface {
169263

170264
VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);
171265

172-
// Mapping from serialized VkValue ids to compute graph ValueRefs
173-
// This will be populated as the compute graph is built
174-
std::unordered_map<uint32_t, ValueRef> ref_mapping;
175-
176-
// A vector which acts as a mapping from VkValue ids (vector indices) to
177-
// VkValues
178-
VkValuesVector value_mapping = flatbuffer_graph->values();
266+
GraphBuilder builder =
267+
GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
179268

180-
// 1. Add all inputs (and corresponding tensors) to the compute graph
181-
UIntVector input_ids = flatbuffer_graph->input_ids();
182-
183-
for (size_t input_index = 0; input_index < input_ids->size();
184-
++input_index) {
185-
const uint32_t input_id = input_ids->Get(input_index);
186-
VkValuePtr input_vk_value = value_mapping->Get(input_id);
187-
188-
VkTensorPtr input_vk_tensor = input_vk_value->value();
189-
190-
ET_CHECK_MSG(
191-
input_vk_tensor->constant_id() < 0,
192-
"Expected constant buffer index for input at index %zu with id %d to be < 0 (since it is non-constant), but got: %d",
193-
input_index,
194-
input_id,
195-
input_vk_tensor->constant_id());
196-
197-
const api::ScalarType& input_dtype =
198-
get_scalar_type(input_vk_tensor->datatype());
199-
200-
UIntVector input_dims_fb = input_vk_tensor->dims();
201-
const std::vector<int64_t> input_dims_vector(
202-
input_dims_fb->cbegin(), input_dims_fb->cend());
203-
204-
const ValueRef input_ref = compute_graph->add_tensor(
205-
input_dims_vector, input_dtype, input_vk_tensor->mem_obj_id());
206-
207-
ref_mapping[input_id] = input_ref;
208-
compute_graph->set_input_tensor(input_ref);
209-
}
210-
211-
// 2. Add all ops to the graph
212-
// TODO: Generalize for ops that don't have 2 inputs and 1 output.
213-
for (OpCallPtr op_call : *(flatbuffer_graph->chain())) {
214-
std::string op_name = op_call->name()->str();
215-
216-
ET_CHECK_MSG(
217-
op_call->args() != nullptr && op_call->args()->size() == 3,
218-
"Vulkan currently only supports OperatorCall with 3 args");
219-
const auto arg_ids = op_call->args()->data();
220-
221-
const uint32_t input1_id = arg_ids[0];
222-
const uint32_t input2_id = arg_ids[1];
223-
const uint32_t output_id = arg_ids[2];
224-
225-
const ValueRef input1_ref = get_value_ref(
226-
input1_id,
227-
flatbuffer_graph,
228-
compute_graph,
229-
ref_mapping,
230-
value_mapping,
231-
constant_data);
232-
233-
const ValueRef input2_ref = get_value_ref(
234-
input2_id,
235-
flatbuffer_graph,
236-
compute_graph,
237-
ref_mapping,
238-
value_mapping,
239-
constant_data);
240-
241-
ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());
242-
auto vkFn = getOpsFn(op_name);
243-
const at::native::vulkan::ValueRef output_ref = vkFn(
244-
*compute_graph,
245-
{input1_ref,
246-
input2_ref,
247-
1,
248-
value_mapping->Get(output_id)->value()->mem_obj_id()});
249-
250-
ref_mapping[output_id] = output_ref;
251-
}
252-
253-
// 3. Add all outputs to the compute graph
254-
for (const uint32_t output_id : *flatbuffer_graph->output_ids()) {
255-
const ValueRef output_ref = ref_mapping[output_id];
256-
compute_graph->set_output_tensor(output_ref);
257-
}
269+
builder.build_graph();
258270

259271
compute_graph->encode_prepack();
260272
compute_graph->prepack();

backends/vulkan/runtime/graph/ComputeGraph.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ ValueRef ComputeGraph::add_staging(
7777
return idx;
7878
}
7979

80+
ValueRef ComputeGraph::add_string(std::string&& str) {
81+
ValueRef idx(static_cast<int>(values_.size()));
82+
values_.emplace_back(std::move(str));
83+
return idx;
84+
}
85+
8086
ValueRef ComputeGraph::set_input_tensor(
8187
const ValueRef idx,
8288
const bool use_staging) {

0 commit comments

Comments
 (0)