@@ -166,18 +166,25 @@ void add_matmul_optimized_node(
       /* passthrough = */ true);

   // Ensure mat1 is width packed
-  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  ValueRef mat1_W_packed = mat1;

   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    mat1_W_packed = mat1_tmp;
+    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+  }

   const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);

   // Ensure mat2 to height packed
   ValueRef mat2_packed = mat2;
   const utils::GPUMemoryLayout mat2_layout =
       mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
+  TmpTensor mat2_tmp(
+      &graph, graph.sizes_of(mat2), graph.dtype_of(mat2), mat2_layout);
   if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
-    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
+    mat2_packed = mat2_tmp;
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }