Skip to content

Commit c4b9544

Browse files
authored
[Backport to 15] Align translation of OpCooperativeMatrixLengthKHR to match the spec (#2964) (#2997)
`SPV_KHR_cooperative_matrix` extension defines that the only argument accepted in this instruction is `Matrix Type <id>`, not the pointer to an actual matrix.
1 parent 7253825 commit c4b9544

File tree

7 files changed

+28
-8
lines changed

7 files changed

+28
-8
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3501,8 +3501,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
35013501
Func->addFnAttr(Attribute::Convergent);
35023502
}
35033503
CallInst *Call;
3504-
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR &&
3505-
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
3504+
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR) {
35063505
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
35073506
// a Type instead of a Value.
35083507
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5626,6 +5626,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
56265626
transValue(CI->getArgOperand(2), BB), BB);
56275627
return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB);
56285628
}
5629+
case OpCooperativeMatrixLengthKHR: {
5630+
return BM->addCooperativeMatrixLengthKHRInst(
5631+
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
5632+
}
56295633
default: {
56305634
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
56315635
return BM->addUnaryInst(OC, transType(CI->getType()),

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ class SPIRVModuleImpl : public SPIRVModule {
300300
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
301301
SPIRVTypeCooperativeMatrixKHR *
302302
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
303+
SPIRVInstruction *
304+
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
305+
SPIRVBasicBlock *) override;
303306
SPIRVType *addOpaqueGenericType(Op) override;
304307
SPIRVTypeDeviceEvent *addDeviceEventType() override;
305308
SPIRVTypeQueue *addQueueType() override;
@@ -1056,6 +1059,14 @@ SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
10561059
new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args));
10571060
}
10581061

1062+
SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst(
1063+
SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) {
1064+
return addInstruction(
1065+
SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy,
1066+
getId(), getVec(MatTy->getId()), BB, this),
1067+
BB);
1068+
}
1069+
10591070
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
10601071
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
10611072
}

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ class SPIRVModule {
259259
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
260260
virtual SPIRVTypeCooperativeMatrixKHR *
261261
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
262+
virtual SPIRVInstruction *
263+
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
264+
SPIRVBasicBlock *) = 0;
262265
virtual SPIRVTypeVoid *addVoidType() = 0;
263266
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
264267
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const3]]
2424
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
2525
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]]
26+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
2627
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
2728
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
2829
; CHECK-SPIRV: CooperativeMatrixStoreKHR
2930

3031

3132
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
3233
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3PU3AS4clii
34+
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3)
3335
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_3PU3AS4cl
3436
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z34__spirv_CooperativeMatrixMulAddKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3PU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_3PU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3i(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3)
3537
; CHECK-LLVM: call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(4) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3)
@@ -105,6 +107,7 @@ for.body.i: ; preds = %for.cond.i
105107
%add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
106108
%call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
107109
%call1.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef %call.ascast.i66.i, i64 noundef %_arg_K, i32 noundef 0, i32 noundef 1) #4
110+
%len = tail call spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %call1.i.i)
108111
%div20.i = mul nsw i32 %k.0.i, 12
109112
%conv21.i = zext i32 %div20.i to i64
110113
%mul23.i = mul i64 %mul22.i, %conv21.i
@@ -136,6 +139,8 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6
136139
; Function Attrs: convergent
137140
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
138141

142+
declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) noundef)
143+
139144
; Function Attrs: convergent
140145
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2
141146

test/transcoding/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int32Ty]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const3]]
2222
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const3]]
2323

24-
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy1]] [[#Load1:]]
25-
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
26-
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
24+
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy1]]
25+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy1]]
2726
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy2]]
2827
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
2928
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy2]]

test/transcoding/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const3]]
2525

2626
; CHECK-SPIRV: CooperativeMatrixPrefetchINTEL
27-
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy1]] [[#Load1:]]
28-
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
29-
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
27+
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy1]]
28+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy1]]
3029
; CHECK-SPIRV: CompositeConstruct [[#MatTy2]]
3130
; CHECK-SPIRV: CooperativeMatrixPrefetchINTEL
3231
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]

0 commit comments

Comments
 (0)