Skip to content

Commit 9e029e8

Browse files
ChuanqiXu9lanza
authored andcommitted
[CIR][CodeGen] Handle the case of 'case' after label statement after 'case' (#879)
Motivation example: ``` extern "C" void action1(); extern "C" void action2(); extern "C" void case_follow_label(int v) { switch (v) { case 1: label: case 2: action1(); break; default: action2(); goto label; } } ``` When we compile it, we will meet: ``` case Stmt::CaseStmtClass: case Stmt::DefaultStmtClass: assert(0 && "Should not get here, currently handled directly from SwitchStmt"); break; ``` in `buildStmt`. The cause is clear. We call `buildStmt` when we build the label stmt. To solve this, I think we should be able to build case stmt in buildStmt. But the new problem is, we need to pass the information like caseAttr and condType. So I tried to add such informations in CIRGenFunction as data member.
1 parent 4ae5f1c commit 9e029e8

File tree

3 files changed

+77
-20
lines changed

3 files changed

+77
-20
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

+9-6
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,13 @@ class CIRGenFunction : public CIRGenTypeCache {
478478
// applies to. nullptr if there is no 'musttail' on the current statement.
479479
const clang::CallExpr *MustTailCall = nullptr;
480480

481+
/// The attributes of cases collected during emitting the body of a switch
482+
/// stmt.
483+
llvm::SmallVector<llvm::SmallVector<mlir::Attribute, 4>, 2> caseAttrsStack;
484+
485+
/// The type of the condition for the emitting switch statement.
486+
llvm::SmallVector<mlir::Type, 2> condTypeStack;
487+
481488
clang::ASTContext &getContext() const;
482489

483490
CIRGenBuilderTy &getBuilder() { return builder; }
@@ -1210,13 +1217,9 @@ class CIRGenFunction : public CIRGenTypeCache {
12101217
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
12111218
SmallVector<mlir::Attribute, 4> &caseAttrs);
12121219

1213-
mlir::LogicalResult
1214-
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
1215-
SmallVector<mlir::Attribute, 4> &caseAttrs);
1220+
mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S);
12161221

1217-
mlir::LogicalResult
1218-
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
1219-
SmallVector<mlir::Attribute, 4> &caseAttrs);
1222+
mlir::LogicalResult buildSwitchBody(const clang::Stmt *S);
12201223

12211224
mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
12221225
const CIRGenFunctionInfo &FnInfo);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

+20-14
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,
303303

304304
case Stmt::CaseStmtClass:
305305
case Stmt::DefaultStmtClass:
306-
assert(0 &&
307-
"Should not get here, currently handled directly from SwitchStmt");
306+
return buildSwitchCase(cast<SwitchCase>(*S));
308307
break;
309308

310309
case Stmt::BreakStmtClass:
@@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
715714
return buildCaseDefaultCascade(&S, condType, caseAttrs);
716715
}
717716

718-
mlir::LogicalResult
719-
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
720-
SmallVector<mlir::Attribute, 4> &caseAttrs) {
717+
mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) {
718+
assert(!caseAttrsStack.empty() &&
719+
"build switch case without seeting case attrs");
720+
assert(!condTypeStack.empty() &&
721+
"build switch case without specifying the type of the condition");
722+
721723
if (S.getStmtClass() == Stmt::CaseStmtClass)
722-
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
724+
return buildCaseStmt(cast<CaseStmt>(S), condTypeStack.back(),
725+
caseAttrsStack.back());
723726

724727
if (S.getStmtClass() == Stmt::DefaultStmtClass)
725-
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
728+
return buildDefaultStmt(cast<DefaultStmt>(S), condTypeStack.back(),
729+
caseAttrsStack.back());
726730

727731
llvm_unreachable("expect case or default stmt");
728732
}
@@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
987991
return mlir::success();
988992
}
989993

990-
mlir::LogicalResult CIRGenFunction::buildSwitchBody(
991-
const Stmt *S, mlir::Type condType,
992-
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
994+
mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) {
993995
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
994996
mlir::Block *lastCaseBlock = nullptr;
995997
auto res = mlir::success();
996998
for (auto *c : compoundStmt->body()) {
997999
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
998-
res = buildSwitchCase(*switchCase, condType, caseAttrs);
1000+
res = buildSwitchCase(*switchCase);
9991001
lastCaseBlock = builder.getBlock();
10001002
} else if (lastCaseBlock) {
10011003
// This means it's a random stmt following up a case, just
@@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
10451047
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
10461048
currLexScope->setAsSwitch();
10471049

1048-
llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
1050+
caseAttrsStack.push_back({});
1051+
condTypeStack.push_back(condV.getType());
10491052

1050-
res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
1053+
res = buildSwitchBody(S.getBody());
10511054

10521055
os.addRegions(currLexScope->getSwitchRegions());
1053-
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
1056+
os.addAttribute("cases", builder.getArrayAttr(caseAttrsStack.back()));
1057+
1058+
caseAttrsStack.pop_back();
1059+
condTypeStack.pop_back();
10541060
});
10551061

10561062
if (res.failed())

clang/test/CIR/CodeGen/goto.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,51 @@ extern "C" void multiple_non_case(int v) {
310310
// NOFLAT: cir.label
311311
// NOFLAT: cir.call @action2()
312312
// NOFLAT: cir.break
313+
314+
extern "C" void case_follow_label(int v) {
315+
switch (v) {
316+
case 1:
317+
label:
318+
case 2:
319+
action1();
320+
break;
321+
default:
322+
action2();
323+
goto label;
324+
}
325+
}
326+
327+
// NOFLAT: cir.func @case_follow_label
328+
// NOFLAT: cir.switch
329+
// NOFLAT: case (equal, 1)
330+
// NOFLAT: cir.label "label"
331+
// NOFLAT: cir.yield
332+
// NOFLAT: case (equal, 2)
333+
// NOFLAT: cir.call @action1()
334+
// NOFLAT: cir.break
335+
// NOFLAT: case (default)
336+
// NOFLAT: cir.call @action2()
337+
// NOFLAT: cir.goto "label"
338+
339+
extern "C" void default_follow_label(int v) {
340+
switch (v) {
341+
case 1:
342+
case 2:
343+
action1();
344+
break;
345+
label:
346+
default:
347+
action2();
348+
goto label;
349+
}
350+
}
351+
352+
// NOFLAT: cir.func @default_follow_label
353+
// NOFLAT: cir.switch
354+
// NOFLAT: case (anyof, [1, 2] : !s32i)
355+
// NOFLAT: cir.call @action1()
356+
// NOFLAT: cir.break
357+
// NOFLAT: cir.label "label"
358+
// NOFLAT: case (default)
359+
// NOFLAT: cir.call @action2()
360+
// NOFLAT: cir.goto "label"

0 commit comments

Comments
 (0)