@@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
162
162
// CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
163
163
return %res_shared2 : tensor <4 x8 xi8 >
164
164
}
165
+
166
+ // -----
167
+
168
+ mesh.mesh @mesh_1d (shape = 4 ) // Mesh with a single dimension of size 4; referenced by the mesh.shard annotations in the test function that follows.
169
+
170
+ // CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
171
+ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis ( // Test: matmul whose second input and output are split along their last dimension over @mesh_1d, with the result then requested unsplit.
172
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
173
+ %in1: tensor <4 x6 xi8 >, // first matmul input
174
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
175
+ %in2: tensor <6 x8 xi8 >, // second matmul input
176
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
177
+ %dps_out: tensor <4 x8 xi8 > // destination (outs) operand of the matmul
178
+ // CHECK-SAME: -> tensor<4x8xi8> {
179
+ ) -> tensor <4 x8 xi8 > {
180
+ %in1_replicated1 = mesh.shard %in1 to <@mesh_1d , [[], []]> : tensor <4 x6 xi8 > // both split-axes lists are empty: %in1 is not split along any mesh axis
181
+ %in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d , [[], []]> annotate_for_users : tensor <4 x6 xi8 > // same sharding annotated for users; the expected matmul takes the function argument directly
182
+ // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
183
+ %in2_replicated = mesh.shard %in2 to <@mesh_1d , [[], []]> : tensor <6 x8 xi8 > // %in2 starts with no split axes
184
+ %in2_sharded = mesh.shard %in2_replicated to <@mesh_1d , [[], [0 ]]> annotate_for_users : tensor <6 x8 xi8 > // users see tensor dim 1 split on mesh axis 0; expected as the all_slice above (operand becomes 6x2)
185
+ // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
186
+ %dps_out_replicated = mesh.shard %dps_out to <@mesh_1d , [[], []]> : tensor <4 x8 xi8 > // %dps_out starts with no split axes
187
+ %dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d , [[], [0 ]]> annotate_for_users : tensor <4 x8 xi8 > // users see tensor dim 1 split on mesh axis 0; expected as the all_slice above (operand becomes 4x2)
188
+ // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
189
+ // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
190
+ // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
191
+ // CHECK-SAME: -> tensor<4x2xi8>
192
+ %res = linalg.matmul ins (%in1_replicated2 , %in2_sharded : tensor <4 x6 xi8 >, tensor <6 x8 xi8 >) // written on full-size tensors; the expected matmul above operates on the 6x2 and 4x2 slices
193
+ outs (%dps_out_sharded : tensor <4 x8 xi8 >) -> tensor <4 x8 xi8 >
194
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
195
+ %res_sharded = mesh.shard %res to <@mesh_1d , [[], [0 ]]> : tensor <4 x8 xi8 > // the result carries tensor dim 1 split on mesh axis 0
196
+ %res_replicated = mesh.shard %res_sharded to <@mesh_1d , [[], []]> annotate_for_users : tensor <4 x8 xi8 > // users need it unsplit, hence the expected all_gather above (4x2 back to 4x8)
197
+ // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
198
+ return %res_replicated : tensor <4 x8 xi8 >
199
+ }
0 commit comments