Skip to content

Commit 048bf60

Browse files
committed
Add error checking for matrix scope and use parameters
It should be: MatrixA, MatrixB or Accumulator. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 1010efc commit 048bf60

File tree

11 files changed

+246
-207
lines changed

11 files changed

+246
-207
lines changed

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,23 @@ void SPIRVTypeCooperativeMatrixKHR::decode(std::istream &I) {
336336
Decoder >> Id >> CompType >> Args;
337337
}
338338

339+
void SPIRVTypeCooperativeMatrixKHR::validate() const {
340+
SPIRVEntry::validate();
341+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
342+
SPIRVConstant *UseConst = static_cast<SPIRVConstant *>(this->getUse());
343+
auto InstName = OpCodeNameMap::map(OC);
344+
uint64_t UseValue = UseConst->getZExtIntValue();
345+
SPVErrLog.checkError(
346+
(UseValue <= internal::InternalJointMatrixUse::Accumulator),
347+
SPIRVEC_InvalidInstruction,
348+
InstName + "\nIncorrect Use parameter, should be MatrixA, MatrixB or "
349+
"Accumulator\n");
350+
SPIRVConstant *ScopeConst = static_cast<SPIRVConstant *>(this->getScope());
351+
uint64_t ScopeValue = ScopeConst->getZExtIntValue();
352+
SPVErrLog.checkError(
353+
(ScopeValue <= ScopeInvocation),
354+
SPIRVEC_InvalidInstruction,
355+
InstName + "\nUnsupported Scope parameter\n");
356+
}
357+
339358
} // namespace SPIRV

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,9 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
11241124
SPIRVType *CompType;
11251125
std::vector<SPIRVValue *> Args;
11261126

1127+
protected:
1128+
void validate() const override;
1129+
11271130
public:
11281131
const static Op OC = OpTypeCooperativeMatrixKHR;
11291132
const static SPIRVWord FixedWC = 7;

test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,42 +31,42 @@
3131
; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]]
3232
; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]]
3333

34-
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
35-
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
36-
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
37-
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
34+
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
35+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
36+
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
37+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])
3838

3939

40-
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
41-
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
42-
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
43-
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
40+
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
41+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
42+
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
43+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])
4444

4545

4646
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4747
target triple = "spir64-unknown-unknown"
4848

4949
define void @convert_f_to_bf() {
5050
entry:
51-
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
52-
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
51+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
52+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %0)
5353
ret void
5454
}
5555

5656
define void @convert_bf_to_f() {
5757
entry:
58-
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0)
59-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0)
58+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 0)
59+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %0)
6060
ret void
6161
}
6262

63-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
63+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float noundef)
6464

65-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
65+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
6666

67-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
67+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef)
6868

69-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef)
69+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) noundef)
7070

7171
!llvm.module.flags = !{!0, !1, !2, !3, !4}
7272
!llvm.ident = !{!5}

0 commit comments

Comments
 (0)