Skip to content

[AArch64] Refactor implementation of FP8 types (NFC) #118969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isFloat32Type() const;
bool isDoubleType() const;
bool isBFloat16Type() const;
bool isMFloat8Type() const;
bool isFloat128Type() const;
bool isIbm128Type() const;
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
Expand Down Expand Up @@ -8532,6 +8533,10 @@ inline bool Type::isBFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::BFloat16);
}

inline bool Type::isMFloat8Type() const {
return isSpecificBuiltinType(BuiltinType::MFloat8);
}

inline bool Type::isFloat128Type() const {
return isSpecificBuiltinType(BuiltinType::Float128);
}
Expand Down
24 changes: 18 additions & 6 deletions clang/include/clang/Basic/AArch64SVEACLETypes.def
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
// - IsBF true for vector of brain float elements.
//===----------------------------------------------------------------------===//

#ifndef SVE_SCALAR_TYPE
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
SVE_TYPE(Name, Id, SingletonId)
#endif

#ifndef SVE_VECTOR_TYPE
#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
SVE_TYPE(Name, Id, SingletonId)
Expand All @@ -72,6 +77,11 @@
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
#endif

#ifndef SVE_VECTOR_TYPE_MFLOAT
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, false)
#endif

#ifndef SVE_VECTOR_TYPE_FLOAT
#define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
Expand Down Expand Up @@ -125,8 +135,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", SveFloat64, SveFloat64Ty

SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, SveBFloat16Ty, 8, 16, 1)

// This is a 8 bits opaque type.
SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1, false)
SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1)

//
// x2
Expand All @@ -148,7 +157,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", "svfloat64x2_t", SveFloat64x2, Sv

SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)

SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2, false)
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2)

//
// x3
Expand All @@ -170,7 +179,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", "svfloat64x3_t", SveFloat64x3, Sv

SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)

SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3, false)
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3)

//
// x4
Expand All @@ -192,19 +201,21 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", "svfloat64x4_t", SveFloat64x4, Sv

SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)

SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4, false)
SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4)

SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, SveBoolx2Ty, 16, 2)
SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4Ty, 16, 4)

SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)

AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
SVE_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)

AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)

#undef SVE_VECTOR_TYPE
#undef SVE_VECTOR_TYPE_MFLOAT
#undef SVE_VECTOR_TYPE_BFLOAT
#undef SVE_VECTOR_TYPE_FLOAT
#undef SVE_VECTOR_TYPE_INT
Expand All @@ -213,4 +224,5 @@ AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloa
#undef SVE_OPAQUE_TYPE
#undef AARCH64_VECTOR_TYPE_MFLOAT
#undef AARCH64_VECTOR_TYPE
#undef SVE_SCALAR_TYPE
#undef SVE_TYPE
37 changes: 30 additions & 7 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2275,6 +2275,11 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
Width = NumEls * ElBits * NF; \
Align = NumEls * ElBits; \
break;
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
case BuiltinType::Id: \
Width = Bits; \
Align = Bits; \
break;
#include "clang/Basic/AArch64SVEACLETypes.def"
#define PPC_VECTOR_TYPE(Name, Id, Size) \
case BuiltinType::Id: \
Expand Down Expand Up @@ -4395,15 +4400,18 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
ElBits, NF) \
case BuiltinType::Id: \
return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
ElBits, NF) \
case BuiltinType::Id: \
return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
case BuiltinType::Id: \
return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
ElBits, NF) \
case BuiltinType::Id: \
return {getIntTypeForBitwidth(ElBits, false), \
llvm::ElementCount::getFixed(NumEls), NF};
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
return {MFloat8Ty, llvm::ElementCount::getFixed(NumEls), NF};
#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"

#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
Expand Down Expand Up @@ -4465,11 +4473,16 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) { \
return SingletonId; \
}
#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
ElBits, NF) \
if (EltTy->isMFloat8Type() && EltTySize == ElBits && \
NumElts == (NumEls * NF) && NumFields == 1) { \
return SingletonId; \
}
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1) \
return SingletonId;
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
} else if (Target->hasRISCVVTypes()) {
uint64_t EltTySize = getTypeSize(EltTy);
Expand Down Expand Up @@ -12216,8 +12229,15 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
RequiresICE, false);
assert(!RequiresICE && "Can't require vector ICE");

// TODO: No way to make AltiVec vectors in builtins yet.
Type = Context.getVectorType(ElementType, NumElements, VectorKind::Generic);
if (ElementType == Context.MFloat8Ty) {
assert((NumElements == 8 || NumElements == 16) &&
"Invalid number of elements");
Type = NumElements == 8 ? Context.MFloat8x8Ty : Context.MFloat8x16Ty;
} else {
// TODO: No way to make AltiVec vectors in builtins yet.
Type =
Context.getVectorType(ElementType, NumElements, VectorKind::Generic);
}
break;
}
case 'E': {
Expand Down Expand Up @@ -12273,6 +12293,9 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
case 'p':
Type = Context.getProcessIDType();
break;
case 'm':
Type = Context.MFloat8Ty;
break;
}

// If there are modifiers and if we're allowed to parse them, go for it.
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3438,6 +3438,11 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
type_name = MangledName; \
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
break;
#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
case BuiltinType::Id: \
type_name = MangledName; \
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
break;
#include "clang/Basic/AArch64SVEACLETypes.def"
#define PPC_VECTOR_TYPE(Name, Id, Size) \
case BuiltinType::Id: \
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id: \
return true;
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id: \
return false;
#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
default:
return false;
Expand Down
16 changes: 13 additions & 3 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
case BuiltinType::Id:
#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id:
#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
{
ASTContext::BuiltinVectorTypeInfo Info =
Context.getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
auto VTy =
llvm::VectorType::get(ConvertType(Info.ElementType), Info.EC);
// The `__mfp8` type maps to `<1 x i8>` which can't be used to build
// a <N x i8> vector type, hence bypass the call to `ConvertType` for
// the element type and create the vector type directly.
auto *EltTy = Info.ElementType->isMFloat8Type()
? llvm::Type::getInt8Ty(getLLVMContext())
: ConvertType(Info.ElementType);
auto *VTy = llvm::VectorType::get(EltTy, Info.EC);
switch (Info.NumVectors) {
default:
llvm_unreachable("Expected 1, 2, 3 or 4 vectors!");
Expand All @@ -529,6 +534,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
}
case BuiltinType::SveCount:
return llvm::TargetExtType::get(getLLVMContext(), "aarch64.svcount");
case BuiltinType::MFloat8:
return llvm::VectorType::get(llvm::Type::getInt8Ty(getLLVMContext()), 1,
false);
#define PPC_VECTOR_TYPE(Name, Id, Size) \
case BuiltinType::Id: \
ResultType = \
Expand Down Expand Up @@ -650,6 +658,8 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
// An ext_vector_type of Bool is really a vector of bits.
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
? llvm::Type::getInt1Ty(getLLVMContext())
: VT->getElementType()->isMFloat8Type()
? llvm::Type::getInt8Ty(getLLVMContext())
: ConvertType(VT->getElementType());
ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
break;
Expand Down
7 changes: 5 additions & 2 deletions clang/lib/CodeGen/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {

case BuiltinType::SChar:
case BuiltinType::UChar:
case BuiltinType::MFloat8:
return llvm::ScalableVectorType::get(
llvm::Type::getInt8Ty(getVMContext()), 16);

Expand Down Expand Up @@ -781,8 +782,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
NPred += Info.NumVectors;
else
NVec += Info.NumVectors;
auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
Info.EC.getKnownMinValue());
llvm::Type *EltTy = Info.ElementType->isMFloat8Type()
? llvm::Type::getInt8Ty(getVMContext())
: CGT.ConvertType(Info.ElementType);
auto *VTy = llvm::ScalableVectorType::get(EltTy, Info.EC.getKnownMinValue());

if (CoerceToSeq.size() + Info.NumVectors > 12)
return false;
Expand Down
4 changes: 2 additions & 2 deletions clang/utils/TableGen/SveEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,15 +448,15 @@ std::string SVEType::builtinBaseType() const {
case TypeKind::PredicatePattern:
return "i";
case TypeKind::Fpm:
return "Wi";
return "UWi";
case TypeKind::Predicate:
return "b";
case TypeKind::BFloat16:
assert(ElementBitwidth == 16 && "Invalid BFloat16!");
return "y";
case TypeKind::MFloat8:
assert(ElementBitwidth == 8 && "Invalid MFloat8!");
return "c";
return "m";
case TypeKind::Float:
switch (ElementBitwidth) {
case 16:
Expand Down
Loading