Skip to content

Commit 5fdbd3f

Browse files
committed
compiler: add support for 'go' on func values
This commit allows starting a new goroutine directly from a func value, not just when the static callee is known. This is necessary to support the whole time package, not just the commonly used subset that was compiled with the SimpleDCE pass enabled.
1 parent 542135c commit 5fdbd3f

File tree

6 files changed

+234
-127
lines changed

6 files changed

+234
-127
lines changed

compiler/compiler.go

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,32 +1059,47 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
10591059
case *ssa.Defer:
10601060
c.emitDefer(frame, instr)
10611061
case *ssa.Go:
1062-
if instr.Call.IsInvoke() {
1063-
c.addError(instr.Pos(), "todo: go on method receiver")
1064-
return
1065-
}
1066-
callee := instr.Call.StaticCallee()
1067-
if callee == nil {
1068-
c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)")
1069-
return
1070-
}
1071-
calleeFn := c.ir.GetFunction(callee)
1072-
10731062
// Get all function parameters to pass to the goroutine.
10741063
var params []llvm.Value
10751064
for _, param := range instr.Call.Args {
10761065
params = append(params, c.getValue(frame, param))
10771066
}
1078-
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
1079-
// For coroutine scheduling, this is only required when calling an
1080-
// external function.
1081-
// For tasks, because all params are stored in a single object, no
1082-
// unnecessary parameters should be stored anyway.
1083-
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1084-
params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle
1085-
}
10861067

1087-
c.emitStartGoroutine(calleeFn.LLVMFn, params)
1068+
// Start a new goroutine.
1069+
if callee := instr.Call.StaticCallee(); callee != nil {
1070+
// Static callee is known. This makes it easier to start a new
1071+
// goroutine.
1072+
calleeFn := c.ir.GetFunction(callee)
1073+
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
1074+
// For coroutine scheduling, this is only required when calling
1075+
// an external function.
1076+
// For tasks, because all params are stored in a single object,
1077+
// no unnecessary parameters should be stored anyway.
1078+
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1079+
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
1080+
}
1081+
c.emitStartGoroutine(calleeFn.LLVMFn, params)
1082+
} else if !instr.Call.IsInvoke() {
1083+
// This is a function pointer.
1084+
// At the moment, two extra params are passed to the newly started
1085+
// goroutine:
1086+
// * The function context, for closures.
1087+
// * The parent handle (for coroutines) or the function pointer
1088+
// itself (for tasks).
1089+
funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
1090+
params = append(params, context) // context parameter
1091+
switch c.selectScheduler() {
1092+
case "coroutines":
1093+
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
1094+
case "tasks":
1095+
params = append(params, funcPtr)
1096+
default:
1097+
panic("unknown scheduler type")
1098+
}
1099+
c.emitStartGoroutine(funcPtr, params)
1100+
} else {
1101+
c.addError(instr.Pos(), "todo: go on interface call")
1102+
}
10881103
case *ssa.If:
10891104
cond := c.getValue(frame, instr.Cond)
10901105
block := instr.Block()

compiler/func-lowering.go

Lines changed: 91 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() {
152152
// There are multiple functions used in a func value that
153153
// implement this signature.
154154
// What we'll do is transform the following:
155-
// rawPtr := runtime.getFuncPtr(fn)
156-
// if func.rawPtr == nil {
155+
// rawPtr := runtime.getFuncPtr(func.ptr)
156+
// if rawPtr == nil {
157157
// runtime.nilPanic()
158158
// }
159-
// result := func.rawPtr(...args, func.context)
159+
// result := rawPtr(...args, func.context)
160160
// into this:
161161
// if false {
162162
// runtime.nilPanic()
@@ -175,95 +175,111 @@ func (c *Compiler) LowerFuncValues() {
175175

176176
// Remove some casts, checks, and the old call which we're going
177177
// to replace.
178-
var funcCall llvm.Value
179-
for _, inttoptr := range getUses(getFuncPtrCall) {
180-
if inttoptr.IsAIntToPtrInst().IsNil() {
178+
for _, callIntPtr := range getUses(getFuncPtrCall) {
179+
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" {
180+
for _, inttoptr := range getUses(callIntPtr) {
181+
if inttoptr.IsAIntToPtrInst().IsNil() {
182+
panic("expected a inttoptr")
183+
}
184+
for _, use := range getUses(inttoptr) {
185+
c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions)
186+
use.EraseFromParentAsInstruction()
187+
}
188+
inttoptr.EraseFromParentAsInstruction()
189+
}
190+
callIntPtr.EraseFromParentAsInstruction()
191+
continue
192+
}
193+
if callIntPtr.IsAIntToPtrInst().IsNil() {
181194
panic("expected inttoptr")
182195
}
183-
for _, ptrUse := range getUses(inttoptr) {
196+
for _, ptrUse := range getUses(callIntPtr) {
184197
if !ptrUse.IsABitCastInst().IsNil() {
185198
for _, bitcastUse := range getUses(ptrUse) {
186-
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
199+
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
200+
panic("expected a call instruction")
201+
}
202+
switch bitcastUse.CalledValue().Name() {
203+
case "runtime.isnil":
204+
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
205+
bitcastUse.EraseFromParentAsInstruction()
206+
default:
187207
panic("expected a call to runtime.isnil")
188208
}
189-
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
190-
bitcastUse.EraseFromParentAsInstruction()
191209
}
192-
ptrUse.EraseFromParentAsInstruction()
193-
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
194-
if !funcCall.IsNil() {
195-
panic("multiple calls on a single runtime.getFuncPtr")
196-
}
197-
funcCall = ptrUse
210+
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
211+
c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
212+
return c.builder.CreateCall(funcPtr, params, "")
213+
}, functions)
198214
} else {
199215
panic("unexpected getFuncPtrCall")
200216
}
217+
ptrUse.EraseFromParentAsInstruction()
201218
}
219+
callIntPtr.EraseFromParentAsInstruction()
202220
}
203-
if funcCall.IsNil() {
204-
panic("expected exactly one call use of a runtime.getFuncPtr")
205-
}
206-
207-
// The block that cannot be reached with correct funcValues (to
208-
// help the optimizer).
209-
c.builder.SetInsertPointBefore(funcCall)
210-
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
211-
c.builder.SetInsertPointAtEnd(defaultBlock)
212-
c.builder.CreateUnreachable()
221+
getFuncPtrCall.EraseFromParentAsInstruction()
222+
}
223+
}
224+
}
225+
}
213226

214-
// Create the switch.
215-
c.builder.SetInsertPointBefore(funcCall)
216-
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
227+
// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
228+
// to the newly created direct calls. The funcID is the number to switch on,
229+
// call is the call instruction to replace, and createCall is the callback that
230+
// actually creates the new call. By changing createCall to something other than
231+
// c.builder.CreateCall, instead of calling a function it can start a new
232+
// goroutine for example.
233+
func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
234+
// The block that cannot be reached with correct funcValues (to help the
235+
// optimizer).
236+
c.builder.SetInsertPointBefore(call)
237+
defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
238+
c.builder.SetInsertPointAtEnd(defaultBlock)
239+
c.builder.CreateUnreachable()
217240

218-
// Split right after the switch. We will need to insert a few
219-
// basic blocks in this gap.
220-
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
241+
// Create the switch.
242+
c.builder.SetInsertPointBefore(call)
243+
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
221244

222-
// The 0 case, which is actually a nil check.
223-
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
224-
c.builder.SetInsertPointAtEnd(nilBlock)
225-
c.createRuntimeCall("nilPanic", nil, "")
226-
c.builder.CreateUnreachable()
227-
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
245+
// Split right after the switch. We will need to insert a few basic blocks
246+
// in this gap.
247+
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
228248

229-
// Gather the list of parameters for every call we're going to
230-
// make.
231-
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
232-
for i := range callParams {
233-
callParams[i] = funcCall.Operand(i)
234-
}
249+
// The 0 case, which is actually a nil check.
250+
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
251+
c.builder.SetInsertPointAtEnd(nilBlock)
252+
c.createRuntimeCall("nilPanic", nil, "")
253+
c.builder.CreateUnreachable()
254+
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
235255

236-
// If the call produces a value, we need to get it using a PHI
237-
// node.
238-
phiBlocks := make([]llvm.BasicBlock, len(functions))
239-
phiValues := make([]llvm.Value, len(functions))
240-
for i, fn := range functions {
241-
// Insert a switch case.
242-
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
243-
c.builder.SetInsertPointAtEnd(bb)
244-
result := c.builder.CreateCall(fn.funcPtr, callParams, "")
245-
c.builder.CreateBr(nextBlock)
246-
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
247-
phiBlocks[i] = bb
248-
phiValues[i] = result
249-
}
250-
// Create the PHI node so that the call result flows into the
251-
// next block (after the split). This is only necessary when the
252-
// call produced a value.
253-
if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
254-
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
255-
phi := c.builder.CreatePHI(funcCall.Type(), "")
256-
phi.AddIncoming(phiValues, phiBlocks)
257-
funcCall.ReplaceAllUsesWith(phi)
258-
}
256+
// Gather the list of parameters for every call we're going to make.
257+
callParams := make([]llvm.Value, call.OperandsCount()-1)
258+
for i := range callParams {
259+
callParams[i] = call.Operand(i)
260+
}
259261

260-
// Finally, remove the old instructions.
261-
funcCall.EraseFromParentAsInstruction()
262-
for _, inttoptr := range getUses(getFuncPtrCall) {
263-
inttoptr.EraseFromParentAsInstruction()
264-
}
265-
getFuncPtrCall.EraseFromParentAsInstruction()
266-
}
267-
}
262+
// If the call produces a value, we need to get it using a PHI
263+
// node.
264+
phiBlocks := make([]llvm.BasicBlock, len(functions))
265+
phiValues := make([]llvm.Value, len(functions))
266+
for i, fn := range functions {
267+
// Insert a switch case.
268+
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
269+
c.builder.SetInsertPointAtEnd(bb)
270+
result := createCall(fn.funcPtr, callParams)
271+
c.builder.CreateBr(nextBlock)
272+
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
273+
phiBlocks[i] = bb
274+
phiValues[i] = result
275+
}
276+
// Create the PHI node so that the call result flows into the
277+
// next block (after the split). This is only necessary when the
278+
// call produced a value.
279+
if call.Type().TypeKind() != llvm.VoidTypeKind {
280+
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
281+
phi := c.builder.CreatePHI(call.Type(), "")
282+
phi.AddIncoming(phiValues, phiBlocks)
283+
call.ReplaceAllUsesWith(phi)
268284
}
269285
}

compiler/func.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,15 @@ const (
3232
// funcImplementation picks an appropriate func value implementation for the
3333
// target.
3434
func (c *Compiler) funcImplementation() funcValueImplementation {
35-
if c.GOARCH == "wasm" {
35+
// Always pick the switch implementation, as it allows the use of blocking
36+
// inside a function that is used as a func value.
37+
switch c.selectScheduler() {
38+
case "coroutines":
3639
return funcValueSwitch
37-
} else {
40+
case "tasks":
3841
return funcValueDoubleword
42+
default:
43+
panic("unknown scheduler type")
3944
}
4045
}
4146

0 commit comments

Comments
 (0)