Skip to content

Commit 1750f0e

Browse files
SS-JIA authored and committed
[ET-VK] Implement missing Vulkan operators for Parakeet TDT model
Pull Request resolved: #18059 Add missing operators needed for Parakeet TDT model support: - New symint ops: sym_sub, sym_floordiv, sym_mul in SymIntOps.cpp; register operator.floordiv and operator.mul as ephemeral ops in op_registry.py - New tensor ops: bitwise_not (via unary_op shader with uint8 DTYPE), logical_and (alias for bitwise_and dispatch) - Improve _to_copy: expand dtype support to FP_INT_BOOL_T and use pick_io_storage_fn to restrict to CONTIGUOUS_BUFFER for non-fp conversions - Fix where resize: compute output shape via broadcast across all tensor inputs instead of always using the second input's shape - Add symint support to split: use extract_int_or_symint_list instead of get_int_list in resize_split_node and split_with_sizes_copy_default - Mark scalar_tensor as supporting resize ghstack-source-id: 353546692 @exported-using-ghexport Differential Revision: [D95970159](https://our.internmc.facebook.com/intern/diff/D95970159/)
1 parent bf37dc2 commit 1750f0e

File tree

8 files changed

+147
-51
lines changed

8 files changed

+147
-51
lines changed

backends/vulkan/op_registry.py

Lines changed: 35 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -159,6 +159,8 @@ def update_features_impl(op: OpKey):
159159
torch.ops.aten.sym_size.int,
160160
operator.add,
161161
operator.sub,
162+
operator.floordiv,
163+
operator.mul,
162164
operator.lt,
163165
operator.gt,
164166
operator.ge,
@@ -279,6 +281,26 @@ def register_bitwise_and():
279281
)
280282

281283

284+
@update_features(exir_ops.edge.aten.bitwise_not.default)
285+
def register_bitwise_not():
286+
return OpFeatures(
287+
inputs_storage=utils.ANY_STORAGE,
288+
inputs_dtypes=utils.BOOL_T,
289+
supports_resize=True,
290+
supports_highdim=True,
291+
)
292+
293+
294+
@update_features(exir_ops.edge.aten.logical_and.default)
295+
def register_logical_and():
296+
return OpFeatures(
297+
inputs_storage=utils.ANY_STORAGE,
298+
inputs_dtypes=utils.BOOL_T,
299+
supports_resize=True,
300+
supports_highdim=True,
301+
)
302+
303+
282304
# =============================================================================
283305
# BinaryScalarOp.cpp
284306
# =============================================================================
@@ -301,16 +323,22 @@ def register_pow_tensor_scalar():
301323

302324
@update_features(exir_ops.edge.aten._to_copy.default)
303325
def register_to_copy():
304-
def check_to_copy_node(node: torch.fx.Node) -> bool:
305-
# Only single-arg _to_copy is supported
306-
return len(node.args) == 1
326+
def pick_to_copy_storage(
327+
node: torch.fx.Node,
328+
) -> Tuple[utils.TensorRepSet, utils.TensorRepSet]:
329+
in_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr]
330+
out_dtype = node.meta["val"].dtype
331+
fp_types = {torch.float16, torch.float32}
332+
if in_dtype in fp_types and out_dtype in fp_types:
333+
return utils.ANY_STORAGE, utils.ANY_STORAGE
334+
return utils.CONTIGUOUS_BUFFER, utils.CONTIGUOUS_BUFFER
307335

308336
return OpFeatures(
309337
inputs_storage=utils.ANY_STORAGE,
310-
inputs_dtypes=utils.FP_INT_T,
311-
outputs_dtypes=utils.FP_INT_T,
338+
inputs_dtypes=utils.FP_INT_BOOL_T,
339+
outputs_dtypes=utils.FP_INT_BOOL_T,
312340
supports_resize=True,
313-
are_node_inputs_supported_fn=check_to_copy_node,
341+
pick_io_storage_fn=pick_to_copy_storage,
314342
)
315343

316344

@@ -1336,6 +1364,7 @@ def register_scalar_tensor():
13361364
return OpFeatures(
13371365
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
13381366
inputs_dtypes=utils.FP_INT_T,
1367+
supports_resize=True,
13391368
)
13401369

13411370

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -46,3 +46,6 @@ unary_op:
4646
OPERATOR: leaky_relu(X, A)
4747
- NAME: round
4848
OPERATOR: round(X)
49+
- NAME: bitwise_not_uint8
50+
OPERATOR: 1 - X
51+
DTYPE: uint8

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -214,6 +214,7 @@ REGISTER_OPERATORS {
214214
VK_REGISTER_OP(aten.gt.Tensor, gt);
215215
VK_REGISTER_OP(aten.ge.Tensor, ge);
216216
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
217+
VK_REGISTER_OP(aten.logical_and.default, bitwise_and);
217218
}
218219

219220
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 2 additions & 40 deletions
Original file line number · Diff line number · Diff line change
@@ -9,52 +9,13 @@
99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1312

14-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1513
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1614

1715
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1816

19-
#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>
20-
2117
namespace vkcompute {
2218

23-
using utils::GPUMemoryLayout;
24-
using utils::StorageType;
25-
26-
void resize_split_node(
27-
ComputeGraph* graph,
28-
const std::vector<ArgGroup>& args,
29-
const std::vector<ValueRef>& resize_args) {
30-
(void)resize_args;
31-
const ValueRef input = args.at(0).refs.at(0);
32-
const ValueRef split_sizes_ref = args.at(1).refs.at(0);
33-
const ValueRef dim_ref = args.at(2).refs.at(0);
34-
const ValueRef out_list_ref = args.at(3).refs.at(0);
35-
36-
const ValueListPtr out_list = graph->get_value_list(out_list_ref);
37-
const std::vector<int64_t> split_sizes =
38-
*(graph->get_int_list(split_sizes_ref));
39-
const int64_t dim = graph->extract_scalar<int64_t>(dim_ref);
40-
41-
const int64_t input_ndim = graph->dim_of(input);
42-
const DimIndex dim_index = dim < 0 ? static_cast<DimIndex>(dim)
43-
: static_cast<DimIndex>(dim - input_ndim);
44-
45-
std::vector<int64_t> input_sizes = graph->sizes_of(input);
46-
47-
for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) {
48-
const int64_t split_size = split_sizes.at(split_idx);
49-
const ValueRef out_ref = out_list->at(split_idx);
50-
51-
std::vector<int64_t> out_sizes = input_sizes;
52-
out_sizes.at(dim_index) = split_size;
53-
54-
graph->virtual_resize(out_ref, out_sizes);
55-
}
56-
}
57-
5819
void add_split_node(
5920
ComputeGraph& graph,
6021
const ValueRef input,
@@ -125,7 +86,8 @@ void split_with_sizes_copy_default(
12586
ValueRef out_list_ref = args[3];
12687

12788
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
128-
std::vector<int64_t> split_sizes = *(graph.get_int_list(split_sizes_ref));
89+
std::vector<int64_t> split_sizes =
90+
graph.extract_int_or_symint_list(split_sizes_ref);
12991

13092
add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref);
13193
}

backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp

Lines changed: 87 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -81,6 +81,90 @@ void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
8181
new ExecuteNode(resize_sym_add_node, args));
8282
}
8383

84+
void sym_sub_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
85+
const ValueRef a = args.at(0);
86+
const ValueRef b = args.at(1);
87+
const ValueRef out = args.at(2);
88+
89+
const int32_t a_val = graph->read_symint(a);
90+
const int32_t b_val = graph->read_symint(b);
91+
const int32_t result = a_val - b_val;
92+
93+
graph->set_symint(out, result);
94+
}
95+
96+
void resize_sym_sub_node(
97+
ComputeGraph* graph,
98+
const std::vector<ArgGroup>& args,
99+
const std::vector<ValueRef>& resize_args) {
100+
(void)args;
101+
sym_sub_impl(graph, resize_args);
102+
}
103+
104+
void sym_sub(ComputeGraph& graph, const std::vector<ValueRef>& args) {
105+
sym_sub_impl(&graph, args);
106+
107+
graph.execute_nodes().emplace_back(
108+
new ExecuteNode(resize_sym_sub_node, args));
109+
}
110+
111+
void sym_floordiv_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
112+
const ValueRef a = args.at(0);
113+
const ValueRef b = args.at(1);
114+
const ValueRef out = args.at(2);
115+
116+
const int32_t a_val = graph->read_symint(a);
117+
const int32_t b_val = graph->read_symint(b);
118+
// Floor division: round towards negative infinity
119+
const int32_t result = (a_val ^ b_val) < 0 && a_val % b_val != 0
120+
? a_val / b_val - 1
121+
: a_val / b_val;
122+
123+
graph->set_symint(out, result);
124+
}
125+
126+
void resize_sym_floordiv_node(
127+
ComputeGraph* graph,
128+
const std::vector<ArgGroup>& args,
129+
const std::vector<ValueRef>& resize_args) {
130+
(void)args;
131+
sym_floordiv_impl(graph, resize_args);
132+
}
133+
134+
void sym_floordiv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
135+
sym_floordiv_impl(&graph, args);
136+
137+
graph.execute_nodes().emplace_back(
138+
new ExecuteNode(resize_sym_floordiv_node, args));
139+
}
140+
141+
void sym_mul_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
142+
const ValueRef a = args.at(0);
143+
const ValueRef b = args.at(1);
144+
const ValueRef out = args.at(2);
145+
146+
const int32_t a_val = graph->read_symint(a);
147+
const int32_t b_val = graph->read_symint(b);
148+
const int32_t result = a_val * b_val;
149+
150+
graph->set_symint(out, result);
151+
}
152+
153+
void resize_sym_mul_node(
154+
ComputeGraph* graph,
155+
const std::vector<ArgGroup>& args,
156+
const std::vector<ValueRef>& resize_args) {
157+
(void)args;
158+
sym_mul_impl(graph, resize_args);
159+
}
160+
161+
void sym_mul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
162+
sym_mul_impl(&graph, args);
163+
164+
graph.execute_nodes().emplace_back(
165+
new ExecuteNode(resize_sym_mul_node, args));
166+
}
167+
84168
void select_as_symint_impl(
85169
ComputeGraph* graph,
86170
const std::vector<ArgGroup>& unused,
@@ -132,6 +216,9 @@ void select_as_symint(ComputeGraph& graph, const std::vector<ValueRef>& args) {
132216
REGISTER_OPERATORS {
133217
VK_REGISTER_OP(sym_size.int, sym_size_int);
134218
VK_REGISTER_OP(add, sym_add);
219+
VK_REGISTER_OP(sub, sym_sub);
220+
VK_REGISTER_OP(floordiv, sym_floordiv);
221+
VK_REGISTER_OP(mul, sym_mul);
135222
VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint);
136223
}
137224

backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -30,8 +30,8 @@ bool is_float_type(vkapi::ScalarType dtype) {
3030
}
3131

3232
void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
33-
vkapi::ScalarType in_dtype = graph.dtype_of(in);
34-
vkapi::ScalarType out_dtype = graph.dtype_of(out);
33+
const vkapi::ScalarType in_dtype = graph.dtype_of(in);
34+
const vkapi::ScalarType out_dtype = graph.dtype_of(out);
3535

3636
// Same-dtype or float<->half conversions can use BlitNode
3737
if (in_dtype == out_dtype ||

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -158,6 +158,7 @@ DEFINE_ACTIVATION_FN(hardswish);
158158
DEFINE_ACTIVATION_FN(hardsigmoid);
159159
DEFINE_LEAKY_RELU_FN(leaky_relu);
160160
DEFINE_ACTIVATION_FN(round);
161+
DEFINE_ACTIVATION_FN(bitwise_not);
161162

162163
REGISTER_OPERATORS {
163164
VK_REGISTER_OP(aten.abs.default, abs);
@@ -179,6 +180,7 @@ REGISTER_OPERATORS {
179180
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
180181
VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
181182
VK_REGISTER_OP(aten.round.default, round);
183+
VK_REGISTER_OP(aten.bitwise_not.default, bitwise_not);
182184
}
183185

184186
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Where.cpp

Lines changed: 15 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -21,10 +21,22 @@ void resize_where_node(
2121
const std::vector<ValueRef>& extra_args) {
2222
(void)extra_args;
2323
const ValueRef out = args.at(0).refs.at(0);
24-
const ValueRef self = args.at(1).refs.at(1);
2524

26-
const std::vector<int64_t> self_sizes = graph->sizes_of(self);
27-
graph->virtual_resize(out, self_sizes);
25+
std::vector<int64_t> out_sizes;
26+
for (const ValueRef ref : args.at(1).refs) {
27+
if (!graph->val_is_tensor(ref)) {
28+
continue;
29+
}
30+
const std::vector<int64_t> s = graph->sizes_of(ref);
31+
if (s.size() > out_sizes.size()) {
32+
out_sizes.resize(s.size(), 1);
33+
}
34+
const size_t offset = out_sizes.size() - s.size();
35+
for (size_t i = 0; i < s.size(); i++) {
36+
out_sizes[offset + i] = std::max(out_sizes[offset + i], s[i]);
37+
}
38+
}
39+
graph->virtual_resize(out, out_sizes);
2840
}
2941

3042
void add_where_node(

0 commit comments

Comments (0)