Skip to content

[mlir] Extend tests of SymbolTable::replaceAllSymbolUses. #68780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ingomueller-net
Copy link
Contributor

This is a follow-up commit for 4790578@llvm/llvm-project (#68320) that adds more tests. In particular, the tests now check that the limit op itself is not traversed, i.e., symbols in attributes in of the limit op are not renamed.

This is a follow-up commit for 4790578@llvm/llvm-project (llvm#68320)
that adds more tests. In particular, the tests now check that the
`limit` op itself is not traversed, i.e., symbols in attributes in of
the `limit` op are not renamed.
@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2023

@llvm/pr-subscribers-mlir

Author: Ingo Müller (ingomueller-net)

Changes

This is a follow-up commit for 4790578@llvm/llvm-project (#68320) that adds more tests. In particular, the tests now check that the limit op itself is not traversed, i.e., symbols in attributes in of the limit op are not renamed.


Full diff: https://github.com/llvm/llvm-project/pull/68780.diff

1 Files Affected:

  • (modified) mlir/unittests/IR/SymbolTableTest.cpp (+18-12)
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f4259..12f582874d7f549 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
   void SetUp() override {
     ::test::registerTestDialect(registry);
     context = std::make_unique<MLIRContext>(registry);
+    builder = std::make_unique<OpBuilder>(context.get());
   }
 
   void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
     // Set up IR and find func ops.
     OwningOpRef<ModuleOp> module =
         parseSourceString<ModuleOp>(kInput, context.get());
+    ASSERT_TRUE(module);
     SymbolTable symbolTable(module.get());
     auto opIterator = module->getBody(0)->getOperations().begin();
     auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
     ASSERT_TRUE(succeeded(res));
     ASSERT_TRUE(succeeded(verify(module.get())));
 
-    // Check that it got renamed.
+    // Check that callee of the call op got renamed.
     bool calleeFound = false;
     fooOp->walk([&](CallOpInterface callOp) {
       StringAttr callee = callOp.getCallableForCallee()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
       calleeFound = true;
     });
     EXPECT_TRUE(calleeFound);
+
+    // Check that module attribute did *not* get renamed.
+    auto moduleAttr = (*module)->getAttrOfType<FlatSymbolRefAttr>("test.attr");
+    ASSERT_TRUE(moduleAttr);
+    EXPECT_EQ(moduleAttr.getValue(), StringRef("bar"));
   }
 
   std::unique_ptr<MLIRContext> context;
+  std::unique_ptr<OpBuilder> builder;
 
 private:
   constexpr static llvm::StringLiteral kInput = R"MLIR(
-      module {
+      module attributes { test.attr = @bar } {
         test.conversion_func_op private @foo() {
           "test.conversion_call_op"() { callee=@bar } : () -> ()
           "test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), module);
+        barOp, builder->getStringAttr("baz"), module);
   });
 }
 
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), module);
+        builder->getStringAttr("bar"), builder->getStringAttr("baz"), module);
   });
 }
 
@@ -100,7 +107,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+        barOp, builder->getStringAttr("baz"), &module->getRegion(0));
   });
 }
 
@@ -108,9 +115,9 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
   // Symbol as `StringAttr`, rename within module body.
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
-    return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+    return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"),
+                                            builder->getStringAttr("baz"),
+                                            &module->getRegion(0));
   });
 }
 
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), fooOp);
+        barOp, builder->getStringAttr("baz"), fooOp);
   });
 }
 
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), fooOp);
+        builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp);
   });
 }
 

@ingomueller-net
Copy link
Contributor Author

I have become unsure about the fix in #68320. In this testcase, the intention of a scope/limit seems to be "this op including the attributes but not the regions." That's the opposite of what I have implemented in the "fix."

What might be the actually deviation from that intention is collectSymbolScopes(Operation *symbol, Operation *limit), which says that it collect[s] all of the symbol scopes from 'symbol' to (inclusive) 'limit'. The fact that limit is inclusive here does not seem to fit the test above. Undoing my "fix" and making that function non-inclusive on limit might be another/a better way of making everything consistent. But I have the feeling that there might still be consequences I don't understand...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants