@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
 from typing import cast, List, Optional, Union
 
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
@@ -16,7 +17,7 @@
 from torch.fx import Node
 
 _ScalarType = Union[bool, int, float]
-_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]
+_Argument = Union[Node, List[Node], TensorSpec, _ScalarType, List[_ScalarType], str]
 
 
 class VkGraphBuilder:
@@ -34,6 +35,7 @@ def __init__(self, program: ExportedProgram) -> None:
 
     @staticmethod
     def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
+        # TODO(T182302927): Support more dtypes including float16, int(32|64).
         if torch_dtype == torch.float32:
             return vk_graph_schema.VkDataType.fp32
         else:
@@ -102,33 +104,20 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
         return tensor
 
     def maybe_add_constant_tensor(self, node: Node) -> int:
-        const_buffer_idx = -1
+        constant_id = -1
         if self.is_param_node(node):
-            const_buffer_idx = len(self.const_tensors)
+            constant_id = len(self.const_tensors)
             self.const_tensors.append(self.get_param_tensor(node))
 
-        return const_buffer_idx
-
-    def create_single_tensor_value(self, node: Node) -> int:
-        constant_id = self.maybe_add_constant_tensor(node)
-
-        spec = node.meta.get("spec")
-        assert isinstance(spec, TensorSpec)
-        new_id = len(self.values)
-        if node not in self.node_to_value_ids:
-            self.node_to_value_ids[node] = new_id
-        else:
-            current_ids = self.node_to_value_ids[node]
-            if isinstance(current_ids, int):
-                current_ids = [current_ids, new_id]
-            else:
-                current_ids.append(new_id)
+        return constant_id
 
+    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # Negative id indicates that this tensor will have its own dedicated memory.
         mem_obj_id = -1
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id
 
+        new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
                 value=vk_graph_schema.VkTensor(
@@ -141,16 +130,23 @@ def create_single_tensor_value(self, node: Node) -> int:
         )
         return new_id
 
-    def create_tensor_values(self, node: Node) -> int:
+    def create_node_value(self, node: Node) -> int:
         spec = node.meta.get("spec")
         if isinstance(spec, TensorSpec):
-            return self.create_single_tensor_value(node)
+            constant_id = self.maybe_add_constant_tensor(node)
+            new_id = self.create_tensor_value(spec, constant_id)
+            self.node_to_value_ids[node] = new_id
+            return new_id
+        elif isinstance(spec, tuple):
+            # Create a Value for each element in the tuple, wrap Values in a
+            # ValueList, and map the Node to the ValueList id.
+            new_id = self.create_value_list_value(spec)
+            self.node_to_value_ids[node] = new_id
+            return new_id
         else:
-            raise RuntimeError(
-                "Creating values for nodes with collection types is not supported yet."
-            )
+            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")
 
-    def create_value_list_value(self, arg: List[Node]) -> int:
+    def create_value_list_value(self, arg: List[Node] | tuple) -> int:
         self.values.append(
             vk_graph_schema.VkValue(
                 vk_graph_schema.ValueList(
@@ -201,14 +197,15 @@ def create_string_value(self, string: str) -> int:
 
     def get_or_create_value_for(self, arg: _Argument):
         if isinstance(arg, Node):
-            # If the value has already been created, return the existing id
+            # If the Node has already been processed, return the existing id.
            if arg in self.node_to_value_ids:
                 return self.node_to_value_ids[arg]
-            # Return id for a newly created value
-            return self.create_tensor_values(arg)
+            return self.create_node_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
             # pyre-ignore[6]
             return self.create_value_list_value(arg)
+        elif isinstance(arg, TensorSpec):
+            return self.create_tensor_value(arg)
         elif isinstance(arg, _ScalarType):
             return self.create_scalar_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
@@ -220,13 +217,25 @@ def get_or_create_value_for(self, arg: _Argument):
             raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")
 
     def process_placeholder_node(self, node: Node) -> None:
-        ids = self.create_tensor_values(node)
+        ids = self.create_node_value(node)
         if not self.is_param_node(node):
             if isinstance(ids, int):
                 self.input_ids.append(ids)
             else:
                 self.input_ids += ids
 
+    def process_getitem_node(self, node: Node) -> None:
+        # Find ValueList id from the collection node.
+        collection_node = node.all_input_nodes[0]
+        list_id = self.node_to_value_ids[collection_node]
+
+        # Extract the target Value id from ValueList.
+        valuelist_id = node.args[1]
+        value_id = self.values[list_id].value.items[valuelist_id]
+
+        # Map Node to Value id.
+        self.node_to_value_ids[node] = value_id
+
     def process_call_function_node(self, node) -> None:
         operator_call_args = []
 
@@ -238,12 +247,12 @@ def process_call_function_node(self, node) -> None:
             else:
                 function_arg = schema_arg.default_value
 
-            # Create a value for each function argument. If the argument has been
-            # previously encountered, then use the existing value id.
+            # Create a Value for each function argument. If the argument has been
+            # previously encountered, then use the existing Value id.
             operator_call_args.append(self.get_or_create_value_for(function_arg))
 
         # Add output node
-        operator_call_args.append(self.create_tensor_values(node))
+        operator_call_args.append(self.create_node_value(node))
 
         self.chain.append(
             vk_graph_schema.OperatorCall(
@@ -253,7 +262,7 @@ def process_call_function_node(self, node) -> None:
         )
 
     def process_getattr_node(self, node: Node) -> None:
-        self.create_tensor_values(node)
+        self.create_node_value(node)
 
     def process_output_node(self, node: Node) -> None:
         for out_node in node.all_input_nodes:
@@ -269,7 +278,10 @@ def process_node(self, node: Node) -> None:
         if node.op == "placeholder":
             self.process_placeholder_node(node)
         elif node.op == "call_function":
-            self.process_call_function_node(node)
+            if node.target == operator.getitem:
+                self.process_getitem_node(node)
+            else:
+                self.process_call_function_node(node)
         elif node.op == "get_attr":
             self.process_getattr_node(node)
         elif node.op == "output":
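
As background for the operator.getitem branch added to process_node above: multi-output ops surface in an FX graph as one collection-producing node followed by getitem nodes whose args are laid out as (collection_node, index). The sketch below is illustrative only; it uses plain torch.fx symbolic tracing rather than the ExportedProgram path this builder consumes, and the MaxWithIndices module name is made up for the example.

import operator

import torch
from torch.fx import symbolic_trace


class MaxWithIndices(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        out = torch.max(x, dim=1)  # multi-output op: (values, indices)
        return out[0] + 1.0, out[1]


gm = symbolic_trace(MaxWithIndices())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == operator.getitem:
        collection_node, index = node.args
        # Mirrors what process_getitem_node reads:
        #   collection_node -> node.all_input_nodes[0]
        #   index           -> node.args[1]
        print(f"{node.name}: element {index} of {collection_node.name}")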
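
The id bookkeeping shared by create_node_value, create_value_list_value, and process_getitem_node can also be summarized with a toy model. The dataclasses below merely stand in for the real vk_graph_schema.VkValue and ValueList types and are not the ExecuTorch implementation; they only show how a ValueList's items hold the ids of per-element Values and how a getitem lookup resolves to one of them.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class ToyValueList:
    items: List[int]  # ids of the Values belonging to this collection


@dataclass
class ToyValue:
    value: Union[str, ToyValueList]  # the real VkValue holds VkTensor, ValueList, etc.


values: List[ToyValue] = []


def add_tensor_value(name: str) -> int:
    # Analogous to create_tensor_value: a Value's id is its index in the value table.
    values.append(ToyValue(value=name))
    return len(values) - 1


# Analogous to create_node_value on a tuple spec: one Value per element,
# then a ValueList Value wrapping their ids.
element_ids = [add_tensor_value("values_tensor"), add_tensor_value("indices_tensor")]
list_id = len(values)
values.append(ToyValue(value=ToyValueList(items=element_ids)))

# Analogous to process_getitem_node: getitem(collection, 1) resolves to items[1].
assert values[list_id].value.items[1] == element_ids[1]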