@@ -772,6 +772,135 @@ func.func @warpgroup_mma_128_128_64(
return
}

+ // CHECK-LABEL: @warpgroup_mma_store(
+ // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+ func.func @warpgroup_mma_store(
+     %result1 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     %result2 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     %matrixD: memref<128x128xf32, 3>) {
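+ // This function exercises the lowering of nvgpu.warpgroup.mma.store: two
+ // 64x128 accumulator fragments are written back to a 128x128 buffer in
+ // memory space 3 through per-thread scalar stores.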
+ // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
+
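+ // Annotation: each accumulator converts to an !llvm.struct of 64 f32
+ // registers per thread. From the CHECK lines below, the lowering recovers
+ // each thread's matrix coordinates as
+ //   lane = tid % 32, warp = tid / 32
+ //   row  = lane / 4 + warp * 16
+ //   col  = (lane % 4) * 2 + i * 8    (i = register-pair index)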
+ // ### Store {d0, d1} of each thread ###
+
+ // CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
+ // CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
+ // CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
+ // CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
+ // CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
+ // CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
+ // CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
+ // CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
+ // CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
+ // CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
+ // CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
+ // CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+ // CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
+ // CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+ // CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+ // CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+ // CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
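+ // For pair index i = 0, d0 is stored to [row][col] and d1 to [row][col + 1].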
+ // ### Store {d2, d3} of each thread ###
+
+ // CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
+ // CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
+ // CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+ // CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
+ // CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+ // CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+ // CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+ // CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+ // ### Store {d4, d5} of each thread ###
+
+ // CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
+ // CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
+ // CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+ // CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
+ // CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+ // CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+ // CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+ // CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+ // ### Store {d6, d7} of each thread ###
+
+ // CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+ // CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
+ // CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
+ // CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+ // CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
+ // CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+ // CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+ // CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+ // CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+ // The pattern repeats in the same way for the remaining 28 register pairs, up to {d62, d63}.
+
+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+
+ // ### Store {d64, d65} of each thread ###
+
+ // CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
+ // CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
+ // CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
+ // CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
+ // CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
+ // CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
+ // CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
+ // CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
+ // CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
+ // CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
+ // CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
+ // CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
+ // CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
+ // CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
+ // CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
+ // CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
+ // CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
+ // CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
+ // CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
+ // CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
+
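+ // The second fragment (%[[S1]]) lands 64 rows below the first: a constant
+ // 64 (%[[S325]]) is added into the row index before the stores.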
+ // The pattern repeats in the same way for the remaining 31 register pairs, up to {d126, d127}.
+
+ nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
+     !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+     to memref<128x128xf32, 3>
+ return
+ }
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["func.func"]} in %arg1