Skip to content

Commit 5c50ed9

Browse files
keryelllanza
authored andcommitted
[CIR] Remove the !cir.void return type for functions returning void (#1203)
C/C++ functions returning void had an explicit !cir.void return type while not having any returned value, which was breaking a lot of MLIR invariants when the CIR dialect is used in a greater context, for example with the inliner. Now, a C/C++ function returning void has not return type and no return values, which does not break the MLIR invariant about the same number of return types and returned values. This change keeps the same parsing/pretty-printed syntax as before for compatibility.
1 parent a7a0268 commit 5c50ed9

File tree

18 files changed

+194
-46
lines changed

18 files changed

+194
-46
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+8-5
Original file line numberDiff line numberDiff line change
@@ -3474,8 +3474,6 @@ def FuncOp : CIR_Op<"func", [
34743474
/// Returns the results types that the callable region produces when
34753475
/// executed.
34763476
llvm::ArrayRef<mlir::Type> getCallableResults() {
3477-
if (::llvm::isa<cir::VoidType>(getFunctionType().getReturnType()))
3478-
return {};
34793477
return getFunctionType().getReturnTypes();
34803478
}
34813479

@@ -3492,10 +3490,15 @@ def FuncOp : CIR_Op<"func", [
34923490
}
34933491

34943492
/// Returns the argument types of this function.
3495-
llvm::ArrayRef<mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
3493+
llvm::ArrayRef<mlir::Type> getArgumentTypes() {
3494+
return getFunctionType().getInputs();
3495+
}
34963496

3497-
/// Returns the result types of this function.
3498-
llvm::ArrayRef<mlir::Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
3497+
/// Returns 0 or 1 result type of this function (0 in the case of a function
3498+
/// returing void)
3499+
llvm::ArrayRef<mlir::Type> getResultTypes() {
3500+
return getFunctionType().getReturnTypes();
3501+
}
34993502

35003503
/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
35013504
/// the 'type' attribute is present and checks if it holds a function type.

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+12-7
Original file line numberDiff line numberDiff line change
@@ -379,22 +379,27 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
379379

380380
```mlir
381381
!cir.func<!bool ()>
382+
!cir.func<!cir.void ()>
382383
!cir.func<!s32i (!s8i, !s8i)>
383384
!cir.func<!s32i (!s32i, ...)>
384385
```
385386
}];
386387

387-
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType,
388+
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, ArrayRefParameter<"mlir::Type">:$returnTypes,
388389
"bool":$varArg);
389390
let assemblyFormat = [{
390-
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
391+
`<` custom<FuncType>($returnTypes, $inputs, $varArg) `>`
391392
}];
392393

393394
let builders = [
395+
// Construct with an actual return type or explicit !cir.void
394396
TypeBuilderWithInferredContext<(ins
395397
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
396398
CArg<"bool", "false">:$isVarArg), [{
397-
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
399+
return $_get(returnType.getContext(), inputs,
400+
::mlir::isa<::cir::VoidType>(returnType) ? llvm::ArrayRef<mlir::Type>{}
401+
: llvm::ArrayRef{returnType},
402+
isVarArg);
398403
}]>
399404
];
400405

@@ -408,11 +413,11 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
408413
/// Returns the number of arguments to the function.
409414
unsigned getNumInputs() const { return getInputs().size(); }
410415

411-
/// Returns the result type of the function as an ArrayRef, enabling better
412-
/// integration with generic MLIR utilities.
413-
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
416+
/// Returns the result type of the function as an actual return type or
417+
/// explicit !cir.void
418+
mlir::Type getReturnType() const;
414419

415-
/// Returns whether the function is returns void.
420+
/// Returns whether the function returns void.
416421
bool isVoid() const;
417422

418423
/// Returns a clone of this function type with the given argument

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
271271
assert(QFT.isCanonical());
272272
const Type *Ty = QFT.getTypePtr();
273273
const FunctionType *FT = cast<FunctionType>(QFT.getTypePtr());
274-
// First, check whether we can build the full fucntion type. If the function
274+
// First, check whether we can build the full function type. If the function
275275
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
276276
// the function type.
277277
assert(isFuncTypeConvertible(FT) && "NYI");

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+30-10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/IR/OpImplementation.h"
3636
#include "mlir/IR/StorageUniquerSupport.h"
3737
#include "mlir/IR/TypeUtilities.h"
38+
#include "mlir/Interfaces/CallInterfaces.h"
3839
#include "mlir/Interfaces/DataLayoutInterfaces.h"
3940
#include "mlir/Interfaces/FunctionImplementation.h"
4041
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -2224,6 +2225,26 @@ void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
22242225
getResAttrsAttrName(result.name));
22252226
}
22262227

2228+
// A specific version of function_interface_impl::parseFunctionSignature able to
2229+
// handle the "-> !void" special fake return type.
2230+
static ParseResult
2231+
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
2232+
SmallVectorImpl<OpAsmParser::Argument> &arguments,
2233+
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
2234+
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
2235+
if (function_interface_impl::parseFunctionArgumentList(parser, allowVariadic,
2236+
arguments, isVariadic))
2237+
return failure();
2238+
if (succeeded(parser.parseOptionalArrow())) {
2239+
if (parser.parseOptionalExclamationKeyword("!void").succeeded())
2240+
// This is just an empty return type and attribute.
2241+
return success();
2242+
return call_interface_impl::parseFunctionResultList(parser, resultTypes,
2243+
resultAttrs);
2244+
}
2245+
return success();
2246+
}
2247+
22272248
ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
22282249
llvm::SMLoc loc = parser.getCurrentLocation();
22292250

@@ -2284,9 +2305,8 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
22842305

22852306
// Parse the function signature.
22862307
bool isVariadic = false;
2287-
if (function_interface_impl::parseFunctionSignatureWithArguments(
2288-
parser, /*allowVariadic=*/false, arguments, isVariadic, resultTypes,
2289-
resultAttrs))
2308+
if (parseFunctionSignature(parser, /*allowVariadic=*/true, arguments,
2309+
isVariadic, resultTypes, resultAttrs))
22902310
return failure();
22912311

22922312
for (auto &arg : arguments)
@@ -2489,13 +2509,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
24892509
p.printSymbolName(getSymName());
24902510
auto fnType = getFunctionType();
24912511
llvm::SmallVector<Type, 1> resultTypes;
2492-
if (!fnType.isVoid())
2493-
function_interface_impl::printFunctionSignature(
2494-
p, *this, fnType.getInputs(), fnType.isVarArg(),
2495-
fnType.getReturnTypes());
2496-
else
2497-
function_interface_impl::printFunctionSignature(
2498-
p, *this, fnType.getInputs(), fnType.isVarArg(), {});
2512+
function_interface_impl::printFunctionSignature(
2513+
p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
24992514

25002515
if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
25012516
p << ' ';
@@ -2564,6 +2579,11 @@ LogicalResult cir::FuncOp::verifyType() {
25642579
if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0)
25652580
return emitError()
25662581
<< "prototyped function must have at least one non-variadic input";
2582+
if (auto rt = type.getReturnTypes();
2583+
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
2584+
return emitOpError("The return type for a function returning void should "
2585+
"be empty instead of an explicit !cir.void");
2586+
25672587
return success();
25682588
}
25692589

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+81-12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/TypeSwitch.h"
3434
#include "llvm/Support/ErrorHandling.h"
3535
#include "llvm/Support/MathExtras.h"
36+
#include <cassert>
3637
#include <optional>
3738

3839
using cir::MissingFeatures;
@@ -42,13 +43,16 @@ using cir::MissingFeatures;
4243
//===----------------------------------------------------------------------===//
4344

4445
static mlir::ParseResult
45-
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
46-
bool &isVarArg);
47-
static void printFuncTypeArgs(mlir::AsmPrinter &p,
48-
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
46+
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
47+
llvm::SmallVector<mlir::Type> &params, bool &isVarArg);
48+
49+
static void printFuncType(mlir::AsmPrinter &p,
50+
mlir::ArrayRef<mlir::Type> returnTypes,
51+
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
4952

5053
static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
5154
mlir::Attribute &addrSpaceAttr);
55+
5256
static void printPointerAddrSpace(mlir::AsmPrinter &p,
5357
mlir::Attribute addrSpaceAttr);
5458

@@ -813,9 +817,46 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
813817
return get(llvm::to_vector(inputs), results[0], isVarArg());
814818
}
815819

816-
mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
817-
llvm::SmallVector<mlir::Type> &params,
818-
bool &isVarArg) {
820+
// A special parser is needed for function returning void to consume the "!void"
821+
// returned type in the case there is no alias defined.
822+
static mlir::ParseResult
823+
parseFuncTypeReturn(mlir::AsmParser &p,
824+
llvm::SmallVector<mlir::Type> &returnTypes) {
825+
if (p.parseOptionalExclamationKeyword("!void").succeeded())
826+
// !void means no return type.
827+
return p.parseLParen();
828+
if (succeeded(p.parseOptionalLParen()))
829+
// If we have already a '(', the function has no return type
830+
return mlir::success();
831+
832+
mlir::Type type;
833+
auto result = p.parseOptionalType(type);
834+
if (!result.has_value())
835+
return mlir::failure();
836+
if (failed(*result) || isa<cir::VoidType>(type))
837+
// No return type specified.
838+
return p.parseLParen();
839+
// Otherwise use the actual type.
840+
returnTypes.push_back(type);
841+
return p.parseLParen();
842+
}
843+
844+
// A special pretty-printer for function returning void to emit a "!void"
845+
// returned type. Note that there is no real type used here since it does not
846+
// appear in the IR and thus the alias might not be defined and cannot be
847+
// referred to. This is why this is a pure syntactic-sugar string which is used.
848+
static void printFuncTypeReturn(mlir::AsmPrinter &p,
849+
mlir::ArrayRef<mlir::Type> returnTypes) {
850+
if (returnTypes.empty())
851+
// Pretty-print no return type as "!void"
852+
p << "!void ";
853+
else
854+
p << returnTypes << ' ';
855+
}
856+
857+
static mlir::ParseResult
858+
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
859+
bool &isVarArg) {
819860
isVarArg = false;
820861
// `(` `)`
821862
if (succeeded(p.parseOptionalRParen()))
@@ -845,8 +886,10 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
845886
return p.parseRParen();
846887
}
847888

848-
void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
849-
bool isVarArg) {
889+
static void printFuncTypeArgs(mlir::AsmPrinter &p,
890+
mlir::ArrayRef<mlir::Type> params,
891+
bool isVarArg) {
892+
p << '(';
850893
llvm::interleaveComma(params, p,
851894
[&p](mlir::Type type) { p.printType(type); });
852895
if (isVarArg) {
@@ -857,11 +900,37 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
857900
p << ')';
858901
}
859902

860-
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
861-
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
903+
static mlir::ParseResult
904+
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
905+
llvm::SmallVector<mlir::Type> &params, bool &isVarArg) {
906+
if (failed(parseFuncTypeReturn(p, returnTypes)))
907+
return failure();
908+
return parseFuncTypeArgs(p, params, isVarArg);
909+
}
910+
911+
static void printFuncType(mlir::AsmPrinter &p,
912+
mlir::ArrayRef<mlir::Type> returnTypes,
913+
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
914+
printFuncTypeReturn(p, returnTypes);
915+
printFuncTypeArgs(p, params, isVarArg);
862916
}
863917

864-
bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
918+
// Return the actual return type or an explicit !cir.void if the function does
919+
// not return anything
920+
mlir::Type FuncType::getReturnType() const {
921+
if (isVoid())
922+
return cir::VoidType::get(getContext());
923+
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes.front();
924+
}
925+
926+
bool FuncType::isVoid() const {
927+
auto rt = static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes;
928+
assert(rt.empty() ||
929+
!mlir::isa<cir::VoidType>(rt.front()) &&
930+
"The return type for a function returning void should be empty "
931+
"instead of a real !cir.void");
932+
return rt.empty();
933+
}
865934

866935
//===----------------------------------------------------------------------===//
867936
// MethodType Definitions

clang/lib/CIR/Dialect/IR/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ add_clang_library(MLIRCIR
1919
LINK_LIBS PUBLIC
2020
MLIRIR
2121
MLIRCIRInterfaces
22+
MLIRCallInterfaces
2223
MLIRDLTIDialect
2324
MLIRDataLayoutInterfaces
25+
MLIRFunctionInterfaces
2426
MLIRFuncDialect
2527
MLIRLoopLikeInterface
2628
MLIRLLVMDialect

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
109109
}
110110
}
111111

112-
return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic());
112+
return FuncType::get(ArgTypes, resultType, FI.isVariadic());
113113
}
114114

115115
/// Convert a CIR type to its ABI-specific default form.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: cir-opt %s | FileCheck %s
2+
// Exercise different ways to encode a function returning void
3+
!s32i = !cir.int<s, 32>
4+
!fnptr1 = !cir.ptr<!cir.func<!cir.void(!s32i)>>
5+
// Note there is no !void alias defined
6+
!fnptr2 = !cir.ptr<!cir.func<!void(!s32i)>>
7+
!fnptr3 = !cir.ptr<!cir.func<(!s32i)>>
8+
module {
9+
cir.func @ind1(%fnptr: !fnptr1, %a : !s32i) {
10+
// CHECK: cir.func @ind1(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
11+
cir.return
12+
}
13+
14+
cir.func @ind2(%fnptr: !fnptr2, %a : !s32i) {
15+
// CHECK: cir.func @ind2(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
16+
cir.return
17+
}
18+
cir.func @ind3(%fnptr: !fnptr3, %a : !s32i) {
19+
// CHECK: cir.func @ind3(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
20+
cir.return
21+
}
22+
cir.func @f1() -> !cir.void {
23+
// CHECK: cir.func @f1() {
24+
cir.return
25+
}
26+
// Note there is no !void alias defined
27+
cir.func @f2() -> !void {
28+
// CHECK: cir.func @f2() {
29+
cir.return
30+
}
31+
cir.func @f3() {
32+
// CHECK: cir.func @f3() {
33+
cir.return
34+
}
35+
}

clang/test/CIR/IR/func.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// RUN: cir-opt %s | FileCheck %s
2-
// XFAIL: *
32

43
!s32i = !cir.int<s, 32>
54
!u8i = !cir.int<u, 8>

clang/test/CIR/IR/invalid.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Test attempts to build bogus CIR
22
// RUN: cir-opt %s -verify-diagnostics -split-input-file
3-
// XFAIL: *
43

54
!u32i = !cir.int<u, 32>
65

clang/test/CIR/Lowering/call.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: cir-opt %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
22
// RUN: cir-translate %s -cir-to-llvmir --disable-cc-lowering | FileCheck %s -check-prefix=LLVM
3-
// XFAIL: *
43

54
!s32i = !cir.int<s, 32>
65
module {

clang/test/CIR/Lowering/func.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
22
// RUN: FileCheck %s -check-prefix=MLIR --input-file=%t.mlir
3-
// XFAIL: *
43

54
!s32i = !cir.int<s, 32>
65
module {

clang/test/CIR/Lowering/hello.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
22
// RUN: FileCheck --input-file=%t.mlir %s
3-
// XFAIL: *
43

54
!s32i = !cir.int<s, 32>
65
!s8i = !cir.int<s, 8>

clang/test/CIR/Lowering/variadics.cir

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: cir-opt %s -cir-to-llvm -o %t.cir
22
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=MLIR
3-
// XFAIL: *
43

54
!s32i = !cir.int<s, 32>
65
!u32i = !cir.int<u, 32>

mlir/include/mlir/IR/OpImplementation.h

+3
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,9 @@ class AsmParser {
923923
/// Parse an optional keyword or string.
924924
virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
925925

926+
/// Parse the given exclamation-prefixed keyword if present.
927+
virtual ParseResult parseOptionalExclamationKeyword(StringRef keyword) = 0;
928+
926929
//===--------------------------------------------------------------------===//
927930
// Attribute/Type Parsing
928931
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)