Skip to content

Commit 01a429c

Browse files
author
Arda Unal
authored
[mlir][mesh] Fix wrong argument passed to targetShardingInUnsplitLastAxis (#95059)
In unsplitLastAxisInResharding, wrong argument was passed when calling targetShardingInUnsplitLastAxis.There weren't any tests to uncover this. I added one in mesh-spmdization.mlir for Linalg and one in resharding-spmdization.mlir for Mesh dialects.
1 parent 1ebda11 commit 01a429c

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
266266
builder.setInsertionPointAfterValue(sourceShard);
267267

268268
MeshShardingAttr targetSharding =
269-
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
269+
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
270270
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
271271
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
272272
Value allGatherResult = builder.create<AllGatherOp>(

mlir/test/Dialect/Linalg/mesh-spmdization.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
162162
// CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
163163
return %res_shared2 : tensor<4x8xi8>
164164
}
165+
166+
// -----
167+
168+
mesh.mesh @mesh_1d(shape = 4)
169+
170+
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
171+
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
172+
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
173+
%in1: tensor<4x6xi8>,
174+
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
175+
%in2: tensor<6x8xi8>,
176+
// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
177+
%dps_out: tensor<4x8xi8>
178+
// CHECK-SAME: -> tensor<4x8xi8> {
179+
) -> tensor<4x8xi8> {
180+
%in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
181+
%in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
182+
// CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
183+
%in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
184+
%in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
185+
// CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
186+
%dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
187+
%dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
188+
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
189+
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
190+
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
191+
// CHECK-SAME: -> tensor<4x2xi8>
192+
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
193+
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
194+
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
195+
%res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
196+
%res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
197+
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
198+
return %res_replicated : tensor<4x8xi8>
199+
}

mlir/test/Dialect/Mesh/resharding-spmdization.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ func.func @unshard_static_axis(
9696
return %1 : tensor<10x14xf32>
9797
}
9898

99+
// CHECK-LABEL: func @unshard_static_last_axis
100+
func.func @unshard_static_last_axis(
101+
// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
102+
%arg0: tensor<10x14xf32>
103+
) -> tensor<10x14xf32> {
104+
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
105+
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
106+
%0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
107+
%1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
108+
// CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
109+
return %1 : tensor<10x14xf32>
110+
}
111+
99112
// CHECK-LABEL: func @unshard_dynamic_axis
100113
func.func @unshard_dynamic_axis(
101114
// CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>

0 commit comments

Comments
 (0)