@@ -371,6 +371,62 @@ module attributes {transform.with_named_sequence} {
371
371
}
372
372
}
373
373
374
+
375
// -----

// Sibling fusion of two scf.forall loops that share the same iteration space
// (4 warps, 32-row tiles via #map) and the same gpu.warp mapping: after
// transform.loop.fuse_sibling the two loops must collapse into a single
// scf.forall producing both results.
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32)
#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
  func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
    // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
    // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
    // CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
    // CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
    // CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
    // CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
    // CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
    // CHECK: scf.forall.in_parallel {
    // CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
    // CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
    // CHECK-NEXT: }
    // CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
    // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
    // First loop: rank-reducing slice of %arg3 accumulated into %arg0.
    %0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) {
      %3 = affine.apply #map(%arg4)
      %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32>
      }
    } {mapping = [#gpu.warp<linear_dim_0>]}
    %1 = tensor.empty() : tensor<128x128xf16>
    // Second loop: f32 -> f16 truncation of a 32x128 tile, accumulated into %arg1.
    %2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) {
      %3 = affine.apply #map(%arg4)
      %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
      %extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
      %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) {
      ^bb0(%in: f32, %out: f16):
        %5 = arith.truncf %in : f32 to f16
        linalg.yield %5 : f16
      } -> tensor<32x128xf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
      }
    } {mapping = [#gpu.warp<linear_dim_0>]}
    return %0, %2 : tensor<128xf32>, tensor<128x128xf16>
  }
}
420
+
421
// Transform script for the test above: match both scf.forall loops, split the
// handle, and fuse the first loop into the second as siblings.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
    %loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
    transform.yield
  }
}
429
+
374
430
// -----
375
431
376
432
func.func @source_for_uses_result_of_target_for_err (%A: tensor <128 xf32 >, %B: tensor <128 xf32 >) -> (tensor <128 xf32 >, tensor <128 xf32 >) {
0 commit comments