Skip to content

Commit 21cbfd6

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Serialize list types from function args (#2404)
Summary: bypass-github-export-checks Pull Request resolved: #2404 In #2271, we already added - IntList - DoubleList - BoolList - ValueList to the schema and the runtime's Value class. Their serialization was incomplete missing two components: 1. Receiving a list in `torch.fx.Node.args`. 2. Receiving a non-tensor in `torch.fx.Node`. This change completes #1. Also, this change fixes a bug where values type `bool` matches both types `bool` and `int` and hence were being added twice. If our type support grows more complex, we can consider using our own types similar to the core Executorch runtime: https://github.com/pytorch/executorch/blob/689796499024fc4a133318d707f4c10db73da967/exir/emit/_emitter.py#L158-L166 ghstack-source-id: 218539049 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54708353 fbshipit-source-id: 8641647b515e201ea63db67115c01c1532ad6566
1 parent 16a3156 commit 21cbfd6

File tree

4 files changed

+102
-14
lines changed

4 files changed

+102
-14
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ class GraphBuilder {
125125
ref_mapping_[fb_id] = ref;
126126
}
127127

128+
template <typename T>
129+
typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
130+
add_scalar_list_to_graph(const uint32_t fb_id, std::vector<T>&& value) {
131+
ValueRef ref = compute_graph_->add_scalar_list(std::move(value));
132+
ref_mapping_[fb_id] = ref;
133+
}
134+
135+
void add_value_list_to_graph(
136+
const uint32_t fb_id,
137+
std::vector<ValueRef>&& value) {
138+
ValueRef ref = compute_graph_->add_value_list(std::move(value));
139+
ref_mapping_[fb_id] = ref;
140+
}
141+
128142
void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
129143
const auto fb_str = value->value_as_String()->string_val();
130144
std::string string(fb_str->cbegin(), fb_str->cend());
@@ -150,6 +164,34 @@ class GraphBuilder {
150164
case vkgraph::GraphTypes::VkTensor:
151165
add_tensor_to_graph(fb_id, value->value_as_VkTensor());
152166
break;
167+
case vkgraph::GraphTypes::IntList:
168+
add_scalar_list_to_graph(
169+
fb_id,
170+
std::vector<int64_t>(
171+
value->value_as_IntList()->items()->cbegin(),
172+
value->value_as_IntList()->items()->cend()));
173+
break;
174+
case vkgraph::GraphTypes::DoubleList:
175+
add_scalar_list_to_graph(
176+
fb_id,
177+
std::vector<double>(
178+
value->value_as_DoubleList()->items()->cbegin(),
179+
value->value_as_DoubleList()->items()->cend()));
180+
break;
181+
case vkgraph::GraphTypes::BoolList:
182+
add_scalar_list_to_graph(
183+
fb_id,
184+
std::vector<bool>(
185+
value->value_as_BoolList()->items()->cbegin(),
186+
value->value_as_BoolList()->items()->cend()));
187+
break;
188+
case vkgraph::GraphTypes::ValueList:
189+
add_value_list_to_graph(
190+
fb_id,
191+
std::vector<ValueRef>(
192+
value->value_as_ValueList()->items()->cbegin(),
193+
value->value_as_ValueList()->items()->cend()));
194+
break;
153195
case vkgraph::GraphTypes::String:
154196
add_string_to_graph(fb_id, value);
155197
break;

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ ValueRef ComputeGraph::add_staging(
122122
return idx;
123123
}
124124

125+
ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
126+
ValueRef idx(static_cast<int>(values_.size()));
127+
values_.emplace_back(std::move(value));
128+
return idx;
129+
}
130+
125131
ValueRef ComputeGraph::add_string(std::string&& str) {
126132
ValueRef idx(static_cast<int>(values_.size()));
127133
values_.emplace_back(std::move(str));

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,13 @@ class ComputeGraph final {
143143

144144
template <typename T>
145145
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
146-
add_scalar_list(std::vector<T>&& values);
146+
add_scalar(T value);
147147

148148
template <typename T>
149149
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
150-
add_scalar(T value);
150+
add_scalar_list(std::vector<T>&& value);
151+
152+
ValueRef add_value_list(std::vector<ValueRef>&& value);
151153

152154
ValueRef add_string(std::string&& str);
153155

@@ -212,17 +214,17 @@ class ComputeGraph final {
212214

213215
template <typename T>
214216
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
215-
ComputeGraph::add_scalar_list(std::vector<T>&& values) {
217+
ComputeGraph::add_scalar(T value) {
216218
ValueRef idx(static_cast<int>(values_.size()));
217-
values_.emplace_back(std::move(values));
219+
values_.emplace_back(value);
218220
return idx;
219221
}
220222

221223
template <typename T>
222224
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
223-
ComputeGraph::add_scalar(T value) {
225+
ComputeGraph::add_scalar_list(std::vector<T>&& value) {
224226
ValueRef idx(static_cast<int>(values_.size()));
225-
values_.emplace_back(value);
227+
values_.emplace_back(std::move(value));
226228
return idx;
227229
}
228230

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Optional, Union
7+
from typing import cast, List, Optional, Union
88

99
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
1010

@@ -15,8 +15,8 @@
1515
from torch.export import ExportedProgram
1616
from torch.fx import Node
1717

18-
_ScalarType = Union[int, bool, float]
19-
_Argument = Union[Node, int, bool, float, str]
18+
_ScalarType = Union[bool, int, float]
19+
_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]
2020

2121

2222
class VkGraphBuilder:
@@ -150,14 +150,46 @@ def create_tensor_values(self, node: Node) -> int:
150150
"Creating values for nodes with collection types is not supported yet."
151151
)
152152

153+
def create_value_list_value(self, arg: List[Node]) -> int:
154+
self.values.append(
155+
vk_graph_schema.VkValue(
156+
vk_graph_schema.ValueList(
157+
items=[self.get_or_create_value_for(e) for e in arg]
158+
)
159+
)
160+
)
161+
return len(self.values) - 1
162+
153163
def create_scalar_value(self, scalar: _ScalarType) -> int:
154164
new_id = len(self.values)
155-
if isinstance(scalar, int):
156-
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
157-
if isinstance(scalar, float):
158-
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
159165
if isinstance(scalar, bool):
160166
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
167+
elif isinstance(scalar, int):
168+
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
169+
elif isinstance(scalar, float):
170+
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
171+
return new_id
172+
173+
def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
174+
new_id = len(self.values)
175+
if isinstance(arg[0], bool):
176+
self.values.append(
177+
vk_graph_schema.VkValue(
178+
vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
179+
)
180+
)
181+
elif isinstance(arg[0], int):
182+
self.values.append(
183+
vk_graph_schema.VkValue(
184+
vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
185+
)
186+
)
187+
elif isinstance(arg[0], float):
188+
self.values.append(
189+
vk_graph_schema.VkValue(
190+
vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
191+
)
192+
)
161193
return new_id
162194

163195
def create_string_value(self, string: str) -> int:
@@ -174,8 +206,14 @@ def get_or_create_value_for(self, arg: _Argument):
174206
return self.node_to_value_ids[arg]
175207
# Return id for a newly created value
176208
return self.create_tensor_values(arg)
177-
elif isinstance(arg, (int, float, bool)):
209+
elif isinstance(arg, list) and isinstance(arg[0], Node):
210+
# pyre-ignore[6]
211+
return self.create_value_list_value(arg)
212+
elif isinstance(arg, _ScalarType):
178213
return self.create_scalar_value(arg)
214+
elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
215+
# pyre-ignore[6]
216+
return self.create_scalar_list_value(arg)
179217
elif isinstance(arg, str):
180218
return self.create_string_value(arg)
181219
else:

0 commit comments

Comments
 (0)