Skip to content

Commit dad2ba0

Browse files
authored
Use TmpTensor for MatMul op.
Differential Revision: D68924743 Pull Request resolved: #8088
1 parent 00b0ce5 commit dad2ba0

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -166,18 +166,25 @@ void add_matmul_optimized_node(
166166
/*passthrough = */ true);
167167

168168
// Ensure mat1 is width packed
169-
ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
169+
TmpTensor mat1_tmp(
170+
&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
171+
ValueRef mat1_W_packed = mat1;
170172
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
171-
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
173+
if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
174+
mat1_W_packed = mat1_tmp;
175+
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
176+
}
172177

173178
const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);
174179

175180
// Ensure mat2 to height packed
176181
ValueRef mat2_packed = mat2;
177182
const utils::GPUMemoryLayout mat2_layout =
178183
mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
184+
TmpTensor mat2_tmp(
185+
&graph, graph.sizes_of(mat2), graph.dtype_of(mat2), mat2_layout);
179186
if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
180-
mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
187+
mat2_packed = mat2_tmp;
181188
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
182189
}
183190

0 commit comments

Comments (0)