 #include <cstdlib> /* strtol */
 #include <memory>
 #include <type_traits>
+#include <vector>
 
 namespace torch {
 namespace executor {
 namespace vulkan {
 namespace {
 
+using namespace at::native::vulkan;
+
 // Flatbuffer types
 using VkGraphPtr = const vkgraph::VkGraph*;
 using OpCallPtr = const vkgraph::OperatorCall*;
@@ -51,102 +54,193 @@ const uint8_t* getConstantDataPtr(
   return constant_data + constant_bytes->offset();
 }
 
-using namespace at::native::vulkan;
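+// Maps the serialized VkDataType to the runtime api::ScalarType. Only fp32 is
+// handled at this point.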
+api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
+  switch (vk_datatype) {
+    case (vkgraph::VkDataType::fp32): {
+      return api::kFloat;
+    }
+  }
+}
+
+GraphConfig generate_config() {
+  const uint32_t submit_frequency = UINT32_MAX;
+
+  const api::CommandPoolConfig cmd_config{
+      4u, // cmdPoolInitialSize
+      2u, // cmdPoolBatchSize
+  };
+
+  const api::DescriptorPoolConfig descriptor_pool_config{
+      1024u, // descriptorPoolMaxSets
+      1024u, // descriptorUniformBufferCount
+      1024u, // descriptorStorageBufferCount
+      1024u, // descriptorCombinedSamplerCount
+      1024u, // descriptorStorageImageCount
+      32u, // descriptorPileSizes
+  };
+
+  const api::QueryPoolConfig query_pool_config{};
+
+  const api::ContextConfig context_config{
+      submit_frequency, // cmdSubmitFrequency
+      cmd_config, // cmdPoolConfig
+      descriptor_pool_config, // descriptorPoolConfig
+      query_pool_config, // queryPoolConfig
+  };
+
+  const GraphConfig graph_config{
+      context_config,
+  };
+
+  return graph_config;
+}
+
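+// Translates a serialized VkGraph flatbuffer into a ComputeGraph, tracking the
+// mapping from flatbuffer value ids (fb_id) to the ValueRefs created in the
+// compute graph.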
+class GraphBuilder {
+  ComputeGraph* compute_graph_;
+  VkGraphPtr flatbuffer_;
+  const uint8_t* constant_data_;
+
+  std::unordered_map<uint32_t, ValueRef> ref_mapping_;
 
-class VulkanBackend final : public PyTorchBackendInterface {
  public:
-  ~VulkanBackend() override = default;
+  explicit GraphBuilder(
+      ComputeGraph* compute_graph,
+      VkGraphPtr flatbuffer,
+      const uint8_t* constant_data)
+      : compute_graph_(compute_graph),
+        flatbuffer_(flatbuffer),
+        constant_data_(constant_data),
+        ref_mapping_() {}
+
+  bool fb_id_exists(const uint32_t fb_id) {
+    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
+        ref_mapping_.find(fb_id);
 
-  bool is_available() const override {
-    return true;
+    return found_ref != ref_mapping_.end();
   }
 
-  api::ScalarType get_scalar_type(
-      const vkgraph::VkDataType& vk_datatype) const {
-    switch (vk_datatype) {
-      case (vkgraph::VkDataType::fp32): {
-        return api::kFloat;
-      }
-    }
+  ValueRef get_fb_id_valueref(const uint32_t fb_id) {
+    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
+        ref_mapping_.find(fb_id);
+
+    ET_CHECK_MSG(
+        found_ref != ref_mapping_.end(),
+        "Trying to extract a value that hasn't yet been added to the graph.");
+
+    return found_ref->second;
   }
 
-  ValueRef get_value_ref(
-      const uint32_t value_id,
-      VkGraphPtr flatbuffer_graph,
-      ComputeGraph* compute_graph,
-      std::unordered_map<uint32_t, ValueRef>& ref_mapping,
-      VkValuesVector value_mapping,
-      const uint8_t* constant_data) const {
-    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
-        ref_mapping.find(value_id);
+  void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
+    const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
+
+    UIntVector dims_fb = tensor_fb->dims();
+    const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
 
-    if (found_ref != ref_mapping.end()) {
-      return found_ref->second;
+    ValueRef ref;
+    if (tensor_fb->constant_id() >= 0) {
+      const uint8_t* tensor_data = getConstantDataPtr(
+          flatbuffer_, tensor_fb->constant_id(), constant_data_);
+
+      ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
+    } else {
+      ref = compute_graph_->add_tensor(
+          dims_vector, dtype, tensor_fb->mem_obj_id());
     }
 
-    VkValuePtr vk_value = value_mapping->Get(value_id);
-    VkTensorPtr vk_tensor = vk_value->value();
+    ref_mapping_[fb_id] = ref;
+  }
+
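+  // Scalars are only accepted for types allowed by is_valid_scalar_type.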
+  template <typename T>
+  typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
+  add_scalar_to_graph(const uint32_t fb_id, T value) {
+    ValueRef ref = compute_graph_->add_scalar(value);
+    ref_mapping_[fb_id] = ref;
+  }
+
+  void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
+    const auto fb_str = value->value_as_String()->string_val();
+    std::string string(fb_str->cbegin(), fb_str->cend());
+    ValueRef ref = compute_graph_->add_string(std::move(string));
+    ref_mapping_[fb_id] = ref;
+  }
 
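+  // Dispatches on the flatbuffer union type to add the corresponding value.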
+  void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
     ET_CHECK_MSG(
-        vk_tensor->constant_id() >= 0,
-        "Only constant buffers are supported when adding tensors to compute graph (indicated by constant_id < 0), but got constant_id of %d",
-        vk_tensor->constant_id());
+        !fb_id_exists(fb_id),
+        "Trying to add a value that has already been added to the graph.");
+
+    switch (value->value_type()) {
+      case vkgraph::GraphTypes::Int:
+        add_scalar_to_graph(fb_id, value->value_as_Int()->int_val());
+        break;
+      case vkgraph::GraphTypes::Double:
+        add_scalar_to_graph(fb_id, value->value_as_Double()->double_val());
+        break;
+      case vkgraph::GraphTypes::Bool:
+        add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val());
+        break;
+      case vkgraph::GraphTypes::VkTensor:
+        add_tensor_to_graph(fb_id, value->value_as_VkTensor());
+        break;
+      case vkgraph::GraphTypes::String:
+        add_string_to_graph(fb_id, value);
+        break;
+      default:
+        ET_CHECK_MSG(false, "Unsupported value type.");
+    }
+  }
 
-    const api::ScalarType& tensor_dtype =
-        get_scalar_type(vk_tensor->datatype());
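+  // Adds every value to the graph, then wires up inputs, operator calls, and
+  // outputs in flatbuffer order.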
+  void build_graph() {
+    // First, add all values to the graph
+    for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
+      VkValuePtr value = flatbuffer_->values()->Get(fb_id);
+      add_value_to_graph(fb_id, value);
+    }
 
-    UIntVector tensor_dims_fb = vk_tensor->dims();
-    const std::vector<int64_t> tensor_dims_vector(
-        tensor_dims_fb->cbegin(), tensor_dims_fb->cend());
+    // Parse the inputs
+    for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
+      const ValueRef ref = get_fb_id_valueref(fb_id);
+      compute_graph_->set_input_tensor(ref);
+    }
 
-    const uint8_t* tensor_data = getConstantDataPtr(
-        flatbuffer_graph, vk_tensor->constant_id(), constant_data);
+    // Parse the operators
+    for (OpCallPtr op_call : *(flatbuffer_->chain())) {
+      std::string op_name = op_call->name()->str();
+      ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());
 
-    const ValueRef value_ref = compute_graph->add_tensorref(
-        tensor_dims_vector, tensor_dtype, tensor_data);
+      const std::vector<int> arg_fb_ids(
+          op_call->args()->cbegin(), op_call->args()->cend());
 
-    ref_mapping[value_id] = value_ref;
+      std::vector<ValueRef> args;
+      for (const int arg_fb_id : arg_fb_ids) {
+        args.push_back(get_fb_id_valueref(arg_fb_id));
+      }
 
-    return value_ref;
+      auto vkFn = getOpsFn(op_name);
+      vkFn(*compute_graph_, args);
+    }
+
+    // Parse the outputs
+    for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
+      const ValueRef ref = get_fb_id_valueref(fb_id);
+      compute_graph_->set_output_tensor(ref);
+    }
   }
+};
 
-  GraphConfig generate_config() const {
-    const uint32_t submit_frequency = UINT32_MAX;
-
-    const api::CommandPoolConfig cmd_config{
-        4u, // cmdPoolInitialSize
-        2u, // cmdPoolBatchSize
-    };
-
-    const api::DescriptorPoolConfig descriptor_pool_config{
-        1024u, // descriptorPoolMaxSets
-        1024u, // descriptorUniformBufferCount
-        1024u, // descriptorStorageBufferCount
-        1024u, // descriptorCombinedSamplerCount
-        1024u, // descriptorStorageImageCount
-        32u, // descriptorPileSizes
-    };
-
-    const api::QueryPoolConfig query_pool_config{};
-
-    const api::ContextConfig context_config{
-        submit_frequency, // cmdSubmitFrequency
-        cmd_config, // cmdPoolConfig
-        descriptor_pool_config, // descriptorPoolConfig
-        query_pool_config, // queryPoolConfig
-    };
-
-    const GraphConfig graph_config{
-        context_config,
-    };
-
-    return graph_config;
+class VulkanBackend final : public PyTorchBackendInterface {
+ public:
+  ~VulkanBackend() override = default;
+
+  bool is_available() const override {
+    return true;
   }
 
   __ET_NODISCARD Error
   compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const {
     Result<VulkanDelegateHeader> header =
         VulkanDelegateHeader::Parse(buffer_pointer);
+
     const uint8_t* flatbuffer_data = nullptr;
     const uint8_t* constant_data = nullptr;
 
@@ -169,92 +263,10 @@ class VulkanBackend final : public PyTorchBackendInterface {
 
     VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);
 
-    // Mapping from serialized VkValue ids to compute graph ValueRefs
-    // This will be populated as the compute graph is built
-    std::unordered_map<uint32_t, ValueRef> ref_mapping;
-
-    // A vector which acts as a mapping from VkValue ids (vector indices) to
-    // VkValues
-    VkValuesVector value_mapping = flatbuffer_graph->values();
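+    // Translate the flatbuffer graph into the compute graph.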
+    GraphBuilder builder =
+        GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
 
-    // 1. Add all inputs (and corresponding tensors) to the compute graph
-    UIntVector input_ids = flatbuffer_graph->input_ids();
-
-    for (size_t input_index = 0; input_index < input_ids->size();
-         ++input_index) {
-      const uint32_t input_id = input_ids->Get(input_index);
-      VkValuePtr input_vk_value = value_mapping->Get(input_id);
-
-      VkTensorPtr input_vk_tensor = input_vk_value->value();
-
-      ET_CHECK_MSG(
-          input_vk_tensor->constant_id() < 0,
-          "Expected constant buffer index for input at index %zu with id %d to be < 0 (since it is non-constant), but got: %d",
-          input_index,
-          input_id,
-          input_vk_tensor->constant_id());
-
-      const api::ScalarType& input_dtype =
-          get_scalar_type(input_vk_tensor->datatype());
-
-      UIntVector input_dims_fb = input_vk_tensor->dims();
-      const std::vector<int64_t> input_dims_vector(
-          input_dims_fb->cbegin(), input_dims_fb->cend());
-
-      const ValueRef input_ref = compute_graph->add_tensor(
-          input_dims_vector, input_dtype, input_vk_tensor->mem_obj_id());
-
-      ref_mapping[input_id] = input_ref;
-      compute_graph->set_input_tensor(input_ref);
-    }
-
-    // 2. Add all ops to the graph
-    // TODO: Generalize for ops that don't have 2 inputs and 1 output.
-    for (OpCallPtr op_call : *(flatbuffer_graph->chain())) {
-      std::string op_name = op_call->name()->str();
-
-      ET_CHECK_MSG(
-          op_call->args() != nullptr && op_call->args()->size() == 3,
-          "Vulkan currently only supports OperatorCall with 3 args");
-      const auto arg_ids = op_call->args()->data();
-
-      const uint32_t input1_id = arg_ids[0];
-      const uint32_t input2_id = arg_ids[1];
-      const uint32_t output_id = arg_ids[2];
-
-      const ValueRef input1_ref = get_value_ref(
-          input1_id,
-          flatbuffer_graph,
-          compute_graph,
-          ref_mapping,
-          value_mapping,
-          constant_data);
-
-      const ValueRef input2_ref = get_value_ref(
-          input2_id,
-          flatbuffer_graph,
-          compute_graph,
-          ref_mapping,
-          value_mapping,
-          constant_data);
-
-      ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());
-      auto vkFn = getOpsFn(op_name);
-      const at::native::vulkan::ValueRef output_ref = vkFn(
-          *compute_graph,
-          {input1_ref,
-           input2_ref,
-           1,
-           value_mapping->Get(output_id)->value()->mem_obj_id()});
-
-      ref_mapping[output_id] = output_ref;
-    }
-
-    // 3. Add all outputs to the compute graph
-    for (const uint32_t output_id : *flatbuffer_graph->output_ids()) {
-      const ValueRef output_ref = ref_mapping[output_id];
-      compute_graph->set_output_tensor(output_ref);
-    }
+    builder.build_graph();
 
     compute_graph->encode_prepack();
     compute_graph->prepack();