Skip to content

[mlir] [tblgen-to-irdl] Add types to tblgen-to-irdl script #108558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions mlir/test/tblgen-to-irdl/CMathDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@ class CMath_Op<string mnemonic, list<Trait> traits = []>
def f32Orf64Type : Or<[CPred<"::llvm::isa<::mlir::F32>">,
CPred<"::llvm::isa<::mlir::F64>">]>;

// CHECK: irdl.type @"!complex"
def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
let parameters = (ins f32Orf64Type:$elementType);
let assemblyFormat = "`<` $elementType `>`";
}

// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: %0 = irdl.base @cmath::@"!complex"
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
def CMath_IdentityOp : CMath_Op<"identity"> {
let results = (outs CMath_ComplexType:$out);
}

// CHECK: irdl.operation @mul {
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
// CHECK-NEXT: %0 = irdl.base @cmath::@"!complex"
// CHECK-NEXT: %1 = irdl.base @cmath::@"!complex"
// CHECK-NEXT: %2 = irdl.base @cmath::@"!complex"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
Expand All @@ -45,7 +47,7 @@ def CMath_MulOp : CMath_Op<"mul"> {

// CHECK: irdl.operation @norm {
// CHECK-NEXT: %0 = irdl.any
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: %1 = irdl.base @cmath::@"!complex"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }
Expand Down
17 changes: 10 additions & 7 deletions mlir/test/tblgen-to-irdl/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

// CHECK: irdl.type @"!singleton_a"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name should just be @singleton_a here.
The reason we had a ! in some places was only in string names (compared to references), to distinguish types and attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can still get name clashes in the symbol names. There are dialects which have operations, types, and attributes share name (I think stablehlo is an example)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait you're right yes!

def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
// CHECK: irdl.type @"!singleton_b"
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
// CHECK: irdl.type @"!singleton_c"
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}


Expand All @@ -26,7 +29,7 @@ def Test_AndOp : Test_Op<"and"> {
let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
}
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
Expand Down Expand Up @@ -79,9 +82,9 @@ def Test_OrOp : Test_Op<"or"> {
let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
}
// CHECK-LABEL: irdl.operation @or {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"!singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base @test::@"!singleton_c"
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
// CHECK-NEXT: }
Expand Down Expand Up @@ -114,8 +117,8 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
Test_SingletonCType:$required);
}
// CHECK-LABEL: irdl.operation @variadicity {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base @test::@"!singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"!singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base @test::@"!singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
// CHECK-NEXT: }
36 changes: 36 additions & 0 deletions mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
}

if (predRec.isSubClassOf("TypeDef")) {
auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
if (dialect == selectedDialect) {
std::string combined = ("!" + predRec.getValueAsString("mnemonic")).str();
SmallVector<FlatSymbolRefAttr> nested = {
SymbolRefAttr::get(ctx, combined)};
auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
return op.getOutput();
}
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, typeName));
Expand Down Expand Up @@ -250,6 +259,12 @@ static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
return opName;
}

/// Returns the name of the type without the dialect prefix.
static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
return opName;
}

/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
Expand Down Expand Up @@ -300,6 +315,19 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
return op;
}

irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
MLIRContext *ctx = builder.getContext();
StringRef typeName = getTypeName(tblgenType);
std::string combined = ("!" + typeName).str();

irdl::TypeOp op = builder.create<irdl::TypeOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, combined));

op.getBody().emplaceBlock();

return op;
}

static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
Expand All @@ -322,6 +350,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
// Set insertion point to start of DialectOp.
builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());

for (const Record *type :
recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef")) {
tblgen::TypeDef tblgenType(type);
if (tblgenType.getDialect().getName() != selectedDialect)
continue;
createIRDLType(builder, tblgenType);
}

for (const Record *def :
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
Expand Down
Loading