
Commit 4321bc5

[ET-VK] Serialize tuple types from Node
In #2271, we already added IntList, DoubleList, BoolList, and ValueList to the schema and the runtime's Value class. Their serialization was incomplete, missing two components:

1. Receiving a list in `torch.fx.Node.args`.
2. Receiving a non-tensor in `torch.fx.Node`.

This change completes #2. It also introduces a dedicated handler for `getitem` nodes, which is required because every `call_function` that outputs a non-tensor `torch.fx.Node` is followed by a special `getitem` `call_function`.

Differential Revision: [D54789099](https://our.internmc.facebook.com/intern/diff/D54789099/)

ghstack-source-id: 218541429
Pull Request resolved: #2405
1 parent a197584 commit 4321bc5
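
For context, the `getitem` pattern referenced in the commit message is how torch.fx represents indexing into a tuple-valued result: the producing `call_function` node is followed by `operator.getitem` nodes, one per consumed element. A minimal sketch, not taken from this PR (the module and the use of `torch.fx.symbolic_trace` below are illustrative only):

```python
import operator

import torch
import torch.fx


class MaxModule(torch.nn.Module):
    def forward(self, x):
        out = torch.max(x, dim=0)  # tuple-valued result: (values, indices)
        return out[0], out[1]      # each index lowers to an operator.getitem node


gm = torch.fx.symbolic_trace(MaxModule())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == operator.getitem:
        # node.args[0] is the tuple-producing node, node.args[1] is the index.
        print(node.args)
```

In the Vulkan graph builder, the tuple-producing node is serialized as a ValueList, and each `getitem` node is then resolved to the list element it selects, as the diff below shows.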

File tree

2 files changed: +48 −34 lines changed


backends/vulkan/partitioner/vulkan_partitioner.py

+2
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import operator
 from typing import final, List, Optional

 import torch
@@ -30,6 +31,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            operator.getitem,
         ]
         return supported

backends/vulkan/serialization/vulkan_graph_builder.py

+46 −34
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import operator
 from typing import cast, List, Optional, Union

 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
@@ -16,7 +17,7 @@
 from torch.fx import Node

 _ScalarType = Union[bool, int, float]
-_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]
+_Argument = Union[Node, List[Node], TensorSpec, _ScalarType, List[_ScalarType], str]


 class VkGraphBuilder:
@@ -34,6 +35,7 @@ def __init__(self, program: ExportedProgram) -> None:

     @staticmethod
     def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
+        # TODO(T182302927): Support more dtypes including float16, int(32|64).
         if torch_dtype == torch.float32:
             return vk_graph_schema.VkDataType.fp32
         else:
@@ -102,33 +104,20 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
         return tensor

     def maybe_add_constant_tensor(self, node: Node) -> int:
-        const_buffer_idx = -1
+        constant_id = -1
         if self.is_param_node(node):
-            const_buffer_idx = len(self.const_tensors)
+            constant_id = len(self.const_tensors)
             self.const_tensors.append(self.get_param_tensor(node))

-        return const_buffer_idx
-
-    def create_single_tensor_value(self, node: Node) -> int:
-        constant_id = self.maybe_add_constant_tensor(node)
-
-        spec = node.meta.get("spec")
-        assert isinstance(spec, TensorSpec)
-        new_id = len(self.values)
-        if node not in self.node_to_value_ids:
-            self.node_to_value_ids[node] = new_id
-        else:
-            current_ids = self.node_to_value_ids[node]
-            if isinstance(current_ids, int):
-                current_ids = [current_ids, new_id]
-            else:
-                current_ids.append(new_id)
+        return constant_id

+    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # Negative id indicates that this tensor will have its own dedicated memory.
         mem_obj_id = -1
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id

+        new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
                 value=vk_graph_schema.VkTensor(
@@ -141,16 +130,23 @@ def create_single_tensor_value(self, node: Node) -> int:
         )
         return new_id

-    def create_tensor_values(self, node: Node) -> int:
+    def create_node_value(self, node: Node) -> int:
         spec = node.meta.get("spec")
         if isinstance(spec, TensorSpec):
-            return self.create_single_tensor_value(node)
+            constant_id = self.maybe_add_constant_tensor(node)
+            new_id = self.create_tensor_value(spec, constant_id)
+            self.node_to_value_ids[node] = new_id
+            return new_id
+        elif isinstance(spec, tuple):
+            # Create a Value for each element in the tuple, wrap Values in a
+            # ValueList, and map the Node to the ValueList id.
+            new_id = self.create_value_list_value(spec)
+            self.node_to_value_ids[node] = new_id
+            return new_id
         else:
-            raise RuntimeError(
-                "Creating values for nodes with collection types is not supported yet."
-            )
+            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")

-    def create_value_list_value(self, arg: List[Node]) -> int:
+    def create_value_list_value(self, arg: List[Node] | tuple) -> int:
         self.values.append(
             vk_graph_schema.VkValue(
                 vk_graph_schema.ValueList(
@@ -201,14 +197,15 @@ def create_string_value(self, string: str) -> int:

     def get_or_create_value_for(self, arg: _Argument):
         if isinstance(arg, Node):
-            # If the value has already been created, return the existing id
+            # If the Node has already been processed, return the existing id.
             if arg in self.node_to_value_ids:
                 return self.node_to_value_ids[arg]
-            # Return id for a newly created value
-            return self.create_tensor_values(arg)
+            return self.create_node_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
             # pyre-ignore[6]
             return self.create_value_list_value(arg)
+        elif isinstance(arg, TensorSpec):
+            return self.create_tensor_value(arg)
         elif isinstance(arg, _ScalarType):
             return self.create_scalar_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
@@ -220,13 +217,25 @@ def get_or_create_value_for(self, arg: _Argument):
             raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")

     def process_placeholder_node(self, node: Node) -> None:
-        ids = self.create_tensor_values(node)
+        ids = self.create_node_value(node)
         if not self.is_param_node(node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

+    def process_getitem_node(self, node: Node) -> None:
+        # Find ValueList id from the collection node.
+        collection_node = node.all_input_nodes[0]
+        list_id = self.node_to_value_ids[collection_node]
+
+        # Extract the target Value id from ValueList.
+        valuelist_id = node.args[1]
+        value_id = self.values[list_id].value.items[valuelist_id]
+
+        # Map Node to Value id.
+        self.node_to_value_ids[node] = value_id
+
     def process_call_function_node(self, node) -> None:
         operator_call_args = []

@@ -238,12 +247,12 @@ def process_call_function_node(self, node) -> None:
             else:
                 function_arg = schema_arg.default_value

-            # Create a value for each function argument. If the argument has been
-            # previously encountered, then use the existing value id.
+            # Create a Value for each function argument. If the argument has been
+            # previously encountered, then use the existing Value id.
             operator_call_args.append(self.get_or_create_value_for(function_arg))

         # Add output node
-        operator_call_args.append(self.create_tensor_values(node))
+        operator_call_args.append(self.create_node_value(node))

         self.chain.append(
             vk_graph_schema.OperatorCall(
@@ -253,7 +262,7 @@ def process_call_function_node(self, node) -> None:
         )

     def process_getattr_node(self, node: Node) -> None:
-        self.create_tensor_values(node)
+        self.create_node_value(node)

     def process_output_node(self, node: Node) -> None:
         for out_node in node.all_input_nodes:
@@ -269,7 +278,10 @@ def process_node(self, node: Node) -> None:
         if node.op == "placeholder":
             self.process_placeholder_node(node)
         elif node.op == "call_function":
-            self.process_call_function_node(node)
+            if node.target == operator.getitem:
+                self.process_getitem_node(node)
+            else:
+                self.process_call_function_node(node)
         elif node.op == "get_attr":
             self.process_getattr_node(node)
         elif node.op == "output":
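
To illustrate the lookup that `process_getitem_node` performs, here is a minimal standalone sketch with simplified stand-ins for `vk_graph_schema.VkValue` and `vk_graph_schema.ValueList` (the real classes carry more fields; only `items` and the id bookkeeping matter here):

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ValueList:  # simplified stand-in for vk_graph_schema.ValueList
    items: List[int]


@dataclass
class VkValue:  # simplified stand-in for vk_graph_schema.VkValue
    value: Union[ValueList, object]


# Values table as the builder might populate it for one tuple-output node:
# ids 0 and 1 are the element tensors, id 2 is the ValueList wrapping them.
values = [VkValue(object()), VkValue(object()), VkValue(ValueList([0, 1]))]
node_to_value_ids = {"tuple_producer": 2}


def resolve_getitem(producer_key: str, index: int) -> int:
    # Mirrors process_getitem_node: find the producer's ValueList id, then
    # return the Value id stored at the requested index.
    list_id = node_to_value_ids[producer_key]
    return values[list_id].value.items[index]


assert resolve_getitem("tuple_producer", 1) == 1
```

Mapping the `getitem` node directly to an existing Value id means no new Value is created for it; downstream operators simply reference the tuple element's tensor.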
