Skip to content

Commit 471d08c

Browse files
aykevlniaow
authored andcommitted
compiler: add support for recursive function types
This adds support for a construct like this: type foo func(fn foo) Unfortunately, LLVM cannot create function pointers that look like this. LLVM only supports named types for structs (not for pointers) and thus can't add a pointer to a function type of the same type to a parameter of that function type. The fix is simple: cast all function pointers to a void function, in LLVM IR: void ()* Raw function pointers are cast to this type before storing, and cast back to the regular function type before calling. This means that function parameters will never refer to its own type because raw function types are fixed at that one type. Somehow, this does have an effect on binary size in some cases. The effect is small and goes both ways. On top of that, there is work underway in LLVM which would make all pointer types opaque (without a pointee type). This would make this whole commit useless and therefore should fix any size increases that might happen. https://llvm.org/docs/OpaquePointers.html
1 parent 4199be9 commit 471d08c

File tree

5 files changed

+18
-10
lines changed

5 files changed

+18
-10
lines changed

compiler/compiler.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import (
2323
// Version of the compiler pacakge. Must be incremented each time the compiler
2424
// package changes in a way that affects the generated LLVM module.
2525
// This version is independent of the TinyGo version number.
26-
const Version = 22 // last change: check for divide by zero
26+
const Version = 23 // last change: fix recursive function types
2727

2828
func init() {
2929
llvm.InitializeAllTargets()
@@ -76,6 +76,7 @@ type compilerContext struct {
7676
targetData llvm.TargetData
7777
intType llvm.Type
7878
i8ptrType llvm.Type // for convenience
79+
rawVoidFuncType llvm.Type // for convenience
7980
funcPtrAddrSpace int
8081
uintptrType llvm.Type
8182
program *ssa.Program
@@ -121,6 +122,7 @@ func newCompilerContext(moduleName string, machine llvm.TargetMachine, config *C
121122
dummyFuncType := llvm.FunctionType(c.ctx.VoidType(), nil, false)
122123
dummyFunc := llvm.AddFunction(c.mod, "tinygo.dummy", dummyFuncType)
123124
c.funcPtrAddrSpace = dummyFunc.Type().PointerAddressSpace()
125+
c.rawVoidFuncType = dummyFunc.Type()
124126
dummyFunc.EraseFromParentAsFunction()
125127

126128
return c

compiler/func.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context
2323
switch c.FuncImplementation {
2424
case "doubleword":
2525
// Closure is: {context, function pointer}
26-
funcValueScalar = funcPtr
26+
funcValueScalar = llvm.ConstBitCast(funcPtr, c.rawVoidFuncType)
2727
case "switch":
2828
funcValueWithSignatureGlobalName := funcPtr.Name() + "$withSignature"
2929
funcValueWithSignatureGlobal := c.mod.NamedGlobal(funcValueWithSignatureGlobalName)
@@ -78,11 +78,11 @@ func (b *builder) extractFuncContext(funcValue llvm.Value) llvm.Value {
7878
// value. This may be an expensive operation.
7979
func (b *builder) decodeFuncValue(funcValue llvm.Value, sig *types.Signature) (funcPtr, context llvm.Value) {
8080
context = b.CreateExtractValue(funcValue, 0, "")
81+
llvmSig := b.getRawFuncType(sig)
8182
switch b.FuncImplementation {
8283
case "doubleword":
83-
funcPtr = b.CreateExtractValue(funcValue, 1, "")
84+
funcPtr = b.CreateBitCast(b.CreateExtractValue(funcValue, 1, ""), llvmSig, "")
8485
case "switch":
85-
llvmSig := b.getRawFuncType(sig)
8686
sigGlobal := b.getFuncSignatureID(sig)
8787
funcPtr = b.createRuntimeCall("getFuncPtr", []llvm.Value{funcValue, sigGlobal}, "")
8888
funcPtr = b.CreateIntToPtr(funcPtr, llvmSig, "")
@@ -96,8 +96,7 @@ func (b *builder) decodeFuncValue(funcValue llvm.Value, sig *types.Signature) (f
9696
func (c *compilerContext) getFuncType(typ *types.Signature) llvm.Type {
9797
switch c.FuncImplementation {
9898
case "doubleword":
99-
rawPtr := c.getRawFuncType(typ)
100-
return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false)
99+
return c.ctx.StructType([]llvm.Type{c.i8ptrType, c.rawVoidFuncType}, false)
101100
case "switch":
102101
return c.getLLVMRuntimeType("funcValue")
103102
default:

compiler/testdata/goroutine-cortex-m-qemu.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ entry:
103103
declare void @runtime.printint32(i32, i8*, i8*)
104104

105105
; Function Attrs: nounwind
106-
define hidden void @main.funcGoroutine(i8* %fn.context, void (i32, i8*, i8*)* %fn.funcptr, i8* %context, i8* %parentHandle) unnamed_addr #0 {
106+
define hidden void @main.funcGoroutine(i8* %fn.context, void ()* %fn.funcptr, i8* %context, i8* %parentHandle) unnamed_addr #0 {
107107
entry:
108108
%0 = call i8* @runtime.alloc(i32 12, i8* undef, i8* null) #0
109109
%1 = bitcast i8* %0 to i32*
@@ -112,8 +112,8 @@ entry:
112112
%3 = bitcast i8* %2 to i8**
113113
store i8* %fn.context, i8** %3, align 4
114114
%4 = getelementptr inbounds i8, i8* %0, i32 8
115-
%5 = bitcast i8* %4 to void (i32, i8*, i8*)**
116-
store void (i32, i8*, i8*)* %fn.funcptr, void (i32, i8*, i8*)** %5, align 4
115+
%5 = bitcast i8* %4 to void ()**
116+
store void ()* %fn.funcptr, void ()** %5, align 4
117117
%stacksize = call i32 @"internal/task.getGoroutineStackSize"(i32 ptrtoint (void (i8*)* @main.funcGoroutine.gowrapper to i32), i8* undef, i8* undef) #0
118118
call void @"internal/task.start"(i32 ptrtoint (void (i8*)* @main.funcGoroutine.gowrapper to i32), i8* nonnull %0, i32 %stacksize, i8* undef, i8* null) #0
119119
ret void

testdata/calls.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,7 @@ type issue1304 struct {
228228
func (x issue1304) call() {
229229
// nothing to do
230230
}
231+
232+
type recursiveFuncType func(recursiveFuncType)
233+
234+
var recursiveFunction recursiveFuncType

transform/interrupt.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ func LowerInterrupts(mod llvm.Module, sizeLevel int) []error {
6666
softwareVector := make(map[int64]llvm.Value)
6767

6868
ctx := mod.Context()
69-
nullptr := llvm.ConstNull(llvm.PointerType(ctx.Int8Type(), 0))
69+
i8ptrType := llvm.PointerType(ctx.Int8Type(), 0)
70+
nullptr := llvm.ConstNull(i8ptrType)
7071
builder := ctx.NewBuilder()
7172
defer builder.Dispose()
7273

@@ -236,6 +237,8 @@ func LowerInterrupts(mod llvm.Module, sizeLevel int) []error {
236237
// Fill the function declaration with the forwarding call.
237238
// In practice, the called function will often be inlined which avoids
238239
// the extra indirection.
240+
handlerFuncPtrType := llvm.PointerType(llvm.FunctionType(ctx.VoidType(), []llvm.Type{num.Type(), i8ptrType, i8ptrType}, false), handlerFuncPtr.Type().PointerAddressSpace())
241+
handlerFuncPtr = llvm.ConstBitCast(handlerFuncPtr, handlerFuncPtrType)
239242
builder.CreateCall(handlerFuncPtr, []llvm.Value{num, handlerContext, nullptr}, "")
240243

241244
// Replace all ptrtoint uses of the global with the interrupt constant.

0 commit comments

Comments
 (0)