Skip to content

Commit ae71543

Browse files
committed
compiler: add support for 'go' on func values
This commit allows starting a new goroutine directly from a func value, not just when the static callee is known. This is necessary to support the whole time package, not just the commonly used subset that was compiled with the SimpleDCE pass enabled.
1 parent fd3309a commit ae71543

File tree

6 files changed

+153
-100
lines changed

6 files changed

+153
-100
lines changed

compiler/compiler.go

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,32 +1034,33 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
10341034
c.addError(instr.Pos(), "todo: go on method receiver")
10351035
return
10361036
}
1037-
callee := instr.Call.StaticCallee()
1038-
if callee == nil {
1039-
c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)")
1040-
return
1041-
}
1042-
calleeFn := c.ir.GetFunction(callee)
1043-
1044-
// Mark this function as a 'go' invocation and break invalid
1045-
// interprocedural optimizations. For example, heap-to-stack
1046-
// transformations are not sound as goroutines can outlive their parent.
1047-
calleeType := calleeFn.LLVMFn.Type()
1048-
calleeValue := c.builder.CreatePtrToInt(calleeFn.LLVMFn, c.uintptrType, "")
1049-
calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "")
1050-
calleeValue = c.builder.CreateIntToPtr(calleeValue, calleeType, "")
10511037

10521038
// Get all function parameters to pass to the goroutine.
10531039
var params []llvm.Value
10541040
for _, param := range instr.Call.Args {
10551041
params = append(params, c.getValue(frame, param))
10561042
}
1057-
if !calleeFn.IsExported() {
1058-
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1059-
params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle
1060-
}
10611043

1062-
c.createCall(calleeValue, params, "")
1044+
// Mark this function as a 'go' invocation and break invalid
1045+
// interprocedural optimizations. For example, heap-to-stack
1046+
// transformations are not sound as goroutines can outlive their parent.
1047+
if callee := instr.Call.StaticCallee(); callee != nil {
1048+
// Static callee is known: this is a regular function call.
1049+
calleeFn := c.ir.GetFunction(callee)
1050+
if !calleeFn.IsExported() {
1051+
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1052+
params = append(params, c.getZeroValue(c.i8ptrType)) // parent coroutine handle
1053+
}
1054+
c.emitStartGoroutine(calleeFn.LLVMFn, params)
1055+
} else if !instr.Call.IsInvoke() {
1056+
// Start a new goroutine by calling a function pointer.
1057+
funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
1058+
params = append(params, context) // context parameter
1059+
params = append(params, c.getZeroValue(c.i8ptrType)) // parent coroutine handle
1060+
c.emitStartGoroutine(funcPtr, params)
1061+
} else {
1062+
c.addError(instr.Pos(), "todo: go on interface call")
1063+
}
10631064
case *ssa.If:
10641065
cond := c.getValue(frame, instr.Cond)
10651066
block := instr.Block()

compiler/func-lowering.go

Lines changed: 91 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() {
152152
// There are multiple functions used in a func value that
153153
// implement this signature.
154154
// What we'll do is transform the following:
155-
// rawPtr := runtime.getFuncPtr(fn)
156-
// if func.rawPtr == nil {
155+
// rawPtr := runtime.getFuncPtr(func.ptr)
156+
// if rawPtr == nil {
157157
// runtime.nilPanic()
158158
// }
159-
// result := func.rawPtr(...args, func.context)
159+
// result := rawPtr(...args, func.context)
160160
// into this:
161161
// if false {
162162
// runtime.nilPanic()
@@ -175,95 +175,111 @@ func (c *Compiler) LowerFuncValues() {
175175

176176
// Remove some casts, checks, and the old call which we're going
177177
// to replace.
178-
var funcCall llvm.Value
179-
for _, inttoptr := range getUses(getFuncPtrCall) {
180-
if inttoptr.IsAIntToPtrInst().IsNil() {
178+
for _, callIntPtr := range getUses(getFuncPtrCall) {
179+
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" {
180+
for _, inttoptr := range getUses(callIntPtr) {
181+
if inttoptr.IsAIntToPtrInst().IsNil() {
182+
panic("expected a inttoptr")
183+
}
184+
for _, use := range getUses(inttoptr) {
185+
c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions)
186+
use.EraseFromParentAsInstruction()
187+
}
188+
inttoptr.EraseFromParentAsInstruction()
189+
}
190+
callIntPtr.EraseFromParentAsInstruction()
191+
continue
192+
}
193+
if callIntPtr.IsAIntToPtrInst().IsNil() {
181194
panic("expected inttoptr")
182195
}
183-
for _, ptrUse := range getUses(inttoptr) {
196+
for _, ptrUse := range getUses(callIntPtr) {
184197
if !ptrUse.IsABitCastInst().IsNil() {
185198
for _, bitcastUse := range getUses(ptrUse) {
186-
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
199+
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
200+
panic("expected a call instruction")
201+
}
202+
switch bitcastUse.CalledValue().Name() {
203+
case "runtime.isnil":
204+
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
205+
bitcastUse.EraseFromParentAsInstruction()
206+
default:
187207
panic("expected a call to runtime.isnil")
188208
}
189-
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
190-
bitcastUse.EraseFromParentAsInstruction()
191209
}
192-
ptrUse.EraseFromParentAsInstruction()
193-
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
194-
if !funcCall.IsNil() {
195-
panic("multiple calls on a single runtime.getFuncPtr")
196-
}
197-
funcCall = ptrUse
210+
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
211+
c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
212+
return c.builder.CreateCall(funcPtr, params, "")
213+
}, functions)
198214
} else {
199215
panic("unexpected getFuncPtrCall")
200216
}
217+
ptrUse.EraseFromParentAsInstruction()
201218
}
219+
callIntPtr.EraseFromParentAsInstruction()
202220
}
203-
if funcCall.IsNil() {
204-
panic("expected exactly one call use of a runtime.getFuncPtr")
205-
}
206-
207-
// The block that cannot be reached with correct funcValues (to
208-
// help the optimizer).
209-
c.builder.SetInsertPointBefore(funcCall)
210-
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
211-
c.builder.SetInsertPointAtEnd(defaultBlock)
212-
c.builder.CreateUnreachable()
221+
getFuncPtrCall.EraseFromParentAsInstruction()
222+
}
223+
}
224+
}
225+
}
213226

214-
// Create the switch.
215-
c.builder.SetInsertPointBefore(funcCall)
216-
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
227+
// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
228+
// to the newly created direct calls. The funcID is the number to switch on,
229+
// call is the call instruction to replace, and createCall is the callback that
230+
// actually creates the new call. By changing createCall to something other than
231+
// c.builder.CreateCall, instead of calling a function it can start a new
232+
// goroutine for example.
233+
func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
234+
// The block that cannot be reached with correct funcValues (to help the
235+
// optimizer).
236+
c.builder.SetInsertPointBefore(call)
237+
defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
238+
c.builder.SetInsertPointAtEnd(defaultBlock)
239+
c.builder.CreateUnreachable()
217240

218-
// Split right after the switch. We will need to insert a few
219-
// basic blocks in this gap.
220-
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
241+
// Create the switch.
242+
c.builder.SetInsertPointBefore(call)
243+
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
221244

222-
// The 0 case, which is actually a nil check.
223-
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
224-
c.builder.SetInsertPointAtEnd(nilBlock)
225-
c.createRuntimeCall("nilPanic", nil, "")
226-
c.builder.CreateUnreachable()
227-
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
245+
// Split right after the switch. We will need to insert a few basic blocks
246+
// in this gap.
247+
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
228248

229-
// Gather the list of parameters for every call we're going to
230-
// make.
231-
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
232-
for i := range callParams {
233-
callParams[i] = funcCall.Operand(i)
234-
}
249+
// The 0 case, which is actually a nil check.
250+
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
251+
c.builder.SetInsertPointAtEnd(nilBlock)
252+
c.createRuntimeCall("nilPanic", nil, "")
253+
c.builder.CreateUnreachable()
254+
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
235255

236-
// If the call produces a value, we need to get it using a PHI
237-
// node.
238-
phiBlocks := make([]llvm.BasicBlock, len(functions))
239-
phiValues := make([]llvm.Value, len(functions))
240-
for i, fn := range functions {
241-
// Insert a switch case.
242-
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
243-
c.builder.SetInsertPointAtEnd(bb)
244-
result := c.builder.CreateCall(fn.funcPtr, callParams, "")
245-
c.builder.CreateBr(nextBlock)
246-
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
247-
phiBlocks[i] = bb
248-
phiValues[i] = result
249-
}
250-
// Create the PHI node so that the call result flows into the
251-
// next block (after the split). This is only necessary when the
252-
// call produced a value.
253-
if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
254-
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
255-
phi := c.builder.CreatePHI(funcCall.Type(), "")
256-
phi.AddIncoming(phiValues, phiBlocks)
257-
funcCall.ReplaceAllUsesWith(phi)
258-
}
256+
// Gather the list of parameters for every call we're going to make.
257+
callParams := make([]llvm.Value, call.OperandsCount()-1)
258+
for i := range callParams {
259+
callParams[i] = call.Operand(i)
260+
}
259261

260-
// Finally, remove the old instructions.
261-
funcCall.EraseFromParentAsInstruction()
262-
for _, inttoptr := range getUses(getFuncPtrCall) {
263-
inttoptr.EraseFromParentAsInstruction()
264-
}
265-
getFuncPtrCall.EraseFromParentAsInstruction()
266-
}
267-
}
262+
// If the call produces a value, we need to get it using a PHI
263+
// node.
264+
phiBlocks := make([]llvm.BasicBlock, len(functions))
265+
phiValues := make([]llvm.Value, len(functions))
266+
for i, fn := range functions {
267+
// Insert a switch case.
268+
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
269+
c.builder.SetInsertPointAtEnd(bb)
270+
result := createCall(fn.funcPtr, callParams)
271+
c.builder.CreateBr(nextBlock)
272+
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
273+
phiBlocks[i] = bb
274+
phiValues[i] = result
275+
}
276+
// Create the PHI node so that the call result flows into the
277+
// next block (after the split). This is only necessary when the
278+
// call produced a value.
279+
if call.Type().TypeKind() != llvm.VoidTypeKind {
280+
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
281+
phi := c.builder.CreatePHI(call.Type(), "")
282+
phi.AddIncoming(phiValues, phiBlocks)
283+
call.ReplaceAllUsesWith(phi)
268284
}
269285
}

compiler/func.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ const (
3232
// funcImplementation picks an appropriate func value implementation for the
3333
// target.
3434
func (c *Compiler) funcImplementation() funcValueImplementation {
35-
if c.GOARCH == "wasm" {
36-
return funcValueSwitch
37-
} else {
38-
return funcValueDoubleword
39-
}
35+
// Always pick the switch implementation, as it allows the use of blocking
36+
// inside a function that is used as a func value.
37+
return funcValueSwitch
4038
}
4139

4240
// createFuncValue creates a function value from a raw function pointer with no

compiler/goroutine.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package compiler
2+
3+
// This file implements the 'go' keyword to start a new goroutine. See
4+
// goroutine-lowering.go for more details.
5+
6+
import "tinygo.org/x/go-llvm"
7+
8+
// emitStartGoroutine starts a new goroutine with the provided function pointer
9+
// and parameters.
10+
func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
11+
// We roundtrip through runtime.makeGoroutine as a signal (to find these
12+
// calls) and to break any optimizations LLVM will try to do: they are
13+
// invalid if we called this as a regular function to be updated later.
14+
calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "")
15+
calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "")
16+
calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "")
17+
c.createCall(calleeValue, params, "")
18+
return llvm.Undef(funcPtr.Type().ElementType().ReturnType())
19+
}

testdata/coroutines.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ func main() {
2828
var printer Printer
2929
printer = &myPrinter{}
3030
printer.Print()
31+
32+
sleepFuncValue(func(x int) {
33+
time.Sleep(1 * time.Millisecond)
34+
println("slept inside func pointer", x)
35+
})
36+
time.Sleep(1 * time.Millisecond)
37+
n := 20
38+
sleepFuncValue(func(x int) {
39+
time.Sleep(1 * time.Millisecond)
40+
println("slept inside closure, with value:", n, x)
41+
})
42+
43+
time.Sleep(2 * time.Millisecond)
3144
}
3245

3346
func sub() {
@@ -47,6 +60,10 @@ func delayedValue() int {
4760
return 42
4861
}
4962

63+
func sleepFuncValue(fn func(int)) {
64+
go fn(8)
65+
}
66+
5067
func nowait() {
5168
println("non-blocking goroutine")
5269
}
@@ -55,7 +72,7 @@ type Printer interface {
5572
Print()
5673
}
5774

58-
type myPrinter struct{
75+
type myPrinter struct {
5976
}
6077

6178
func (i *myPrinter) Print() {

testdata/coroutines.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ value produced after some time: 42
1111
non-blocking goroutine
1212
done with non-blocking goroutine
1313
async interface method call
14+
slept inside func pointer 8
15+
slept inside closure, with value: 20 8

0 commit comments

Comments
 (0)