Skip to content

Commit bbc3046

Browse files
aykevldeadprogram
authored andcommitted
compiler: add support for 'go' on func values
This commit allows starting a new goroutine directly from a func value, not just when the static callee is known. This is necessary to support the whole time package, not just the commonly used subset that was compiled with the SimpleDCE pass enabled.
1 parent e4fc3bb commit bbc3046

File tree

6 files changed

+234
-127
lines changed

6 files changed

+234
-127
lines changed

compiler/compiler.go

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,32 +1060,47 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
10601060
case *ssa.Defer:
10611061
c.emitDefer(frame, instr)
10621062
case *ssa.Go:
1063-
if instr.Call.IsInvoke() {
1064-
c.addError(instr.Pos(), "todo: go on method receiver")
1065-
return
1066-
}
1067-
callee := instr.Call.StaticCallee()
1068-
if callee == nil {
1069-
c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)")
1070-
return
1071-
}
1072-
calleeFn := c.ir.GetFunction(callee)
1073-
10741063
// Get all function parameters to pass to the goroutine.
10751064
var params []llvm.Value
10761065
for _, param := range instr.Call.Args {
10771066
params = append(params, c.getValue(frame, param))
10781067
}
1079-
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
1080-
// For coroutine scheduling, this is only required when calling an
1081-
// external function.
1082-
// For tasks, because all params are stored in a single object, no
1083-
// unnecessary parameters should be stored anyway.
1084-
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1085-
params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle
1086-
}
10871068

1088-
c.emitStartGoroutine(calleeFn.LLVMFn, params)
1069+
// Start a new goroutine.
1070+
if callee := instr.Call.StaticCallee(); callee != nil {
1071+
// Static callee is known. This makes it easier to start a new
1072+
// goroutine.
1073+
calleeFn := c.ir.GetFunction(callee)
1074+
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
1075+
// For coroutine scheduling, this is only required when calling
1076+
// an external function.
1077+
// For tasks, because all params are stored in a single object,
1078+
// no unnecessary parameters should be stored anyway.
1079+
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
1080+
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
1081+
}
1082+
c.emitStartGoroutine(calleeFn.LLVMFn, params)
1083+
} else if !instr.Call.IsInvoke() {
1084+
// This is a function pointer.
1085+
// At the moment, two extra params are passed to the newly started
1086+
// goroutine:
1087+
// * The function context, for closures.
1088+
// * The parent handle (for coroutines) or the function pointer
1089+
// itself (for tasks).
1090+
funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
1091+
params = append(params, context) // context parameter
1092+
switch c.selectScheduler() {
1093+
case "coroutines":
1094+
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
1095+
case "tasks":
1096+
params = append(params, funcPtr)
1097+
default:
1098+
panic("unknown scheduler type")
1099+
}
1100+
c.emitStartGoroutine(funcPtr, params)
1101+
} else {
1102+
c.addError(instr.Pos(), "todo: go on interface call")
1103+
}
10891104
case *ssa.If:
10901105
cond := c.getValue(frame, instr.Cond)
10911106
block := instr.Block()

compiler/func-lowering.go

Lines changed: 91 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() {
152152
// There are multiple functions used in a func value that
153153
// implement this signature.
154154
// What we'll do is transform the following:
155-
// rawPtr := runtime.getFuncPtr(fn)
156-
// if func.rawPtr == nil {
155+
// rawPtr := runtime.getFuncPtr(func.ptr)
156+
// if rawPtr == nil {
157157
// runtime.nilPanic()
158158
// }
159-
// result := func.rawPtr(...args, func.context)
159+
// result := rawPtr(...args, func.context)
160160
// into this:
161161
// if false {
162162
// runtime.nilPanic()
@@ -175,95 +175,111 @@ func (c *Compiler) LowerFuncValues() {
175175

176176
// Remove some casts, checks, and the old call which we're going
177177
// to replace.
178-
var funcCall llvm.Value
179-
for _, inttoptr := range getUses(getFuncPtrCall) {
180-
if inttoptr.IsAIntToPtrInst().IsNil() {
178+
for _, callIntPtr := range getUses(getFuncPtrCall) {
179+
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" {
180+
for _, inttoptr := range getUses(callIntPtr) {
181+
if inttoptr.IsAIntToPtrInst().IsNil() {
182+
panic("expected a inttoptr")
183+
}
184+
for _, use := range getUses(inttoptr) {
185+
c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions)
186+
use.EraseFromParentAsInstruction()
187+
}
188+
inttoptr.EraseFromParentAsInstruction()
189+
}
190+
callIntPtr.EraseFromParentAsInstruction()
191+
continue
192+
}
193+
if callIntPtr.IsAIntToPtrInst().IsNil() {
181194
panic("expected inttoptr")
182195
}
183-
for _, ptrUse := range getUses(inttoptr) {
196+
for _, ptrUse := range getUses(callIntPtr) {
184197
if !ptrUse.IsABitCastInst().IsNil() {
185198
for _, bitcastUse := range getUses(ptrUse) {
186-
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
199+
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
200+
panic("expected a call instruction")
201+
}
202+
switch bitcastUse.CalledValue().Name() {
203+
case "runtime.isnil":
204+
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
205+
bitcastUse.EraseFromParentAsInstruction()
206+
default:
187207
panic("expected a call to runtime.isnil")
188208
}
189-
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
190-
bitcastUse.EraseFromParentAsInstruction()
191209
}
192-
ptrUse.EraseFromParentAsInstruction()
193-
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
194-
if !funcCall.IsNil() {
195-
panic("multiple calls on a single runtime.getFuncPtr")
196-
}
197-
funcCall = ptrUse
210+
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
211+
c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
212+
return c.builder.CreateCall(funcPtr, params, "")
213+
}, functions)
198214
} else {
199215
panic("unexpected getFuncPtrCall")
200216
}
217+
ptrUse.EraseFromParentAsInstruction()
201218
}
219+
callIntPtr.EraseFromParentAsInstruction()
202220
}
203-
if funcCall.IsNil() {
204-
panic("expected exactly one call use of a runtime.getFuncPtr")
205-
}
206-
207-
// The block that cannot be reached with correct funcValues (to
208-
// help the optimizer).
209-
c.builder.SetInsertPointBefore(funcCall)
210-
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
211-
c.builder.SetInsertPointAtEnd(defaultBlock)
212-
c.builder.CreateUnreachable()
221+
getFuncPtrCall.EraseFromParentAsInstruction()
222+
}
223+
}
224+
}
225+
}
213226

214-
// Create the switch.
215-
c.builder.SetInsertPointBefore(funcCall)
216-
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
227+
// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
228+
// to the newly created direct calls. The funcID is the number to switch on,
229+
// call is the call instruction to replace, and createCall is the callback that
230+
// actually creates the new call. By changing createCall to something other than
231+
// c.builder.CreateCall, instead of calling a function it can start a new
232+
// goroutine for example.
233+
func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
234+
// The block that cannot be reached with correct funcValues (to help the
235+
// optimizer).
236+
c.builder.SetInsertPointBefore(call)
237+
defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
238+
c.builder.SetInsertPointAtEnd(defaultBlock)
239+
c.builder.CreateUnreachable()
217240

218-
// Split right after the switch. We will need to insert a few
219-
// basic blocks in this gap.
220-
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
241+
// Create the switch.
242+
c.builder.SetInsertPointBefore(call)
243+
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
221244

222-
// The 0 case, which is actually a nil check.
223-
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
224-
c.builder.SetInsertPointAtEnd(nilBlock)
225-
c.createRuntimeCall("nilPanic", nil, "")
226-
c.builder.CreateUnreachable()
227-
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
245+
// Split right after the switch. We will need to insert a few basic blocks
246+
// in this gap.
247+
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
228248

229-
// Gather the list of parameters for every call we're going to
230-
// make.
231-
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
232-
for i := range callParams {
233-
callParams[i] = funcCall.Operand(i)
234-
}
249+
// The 0 case, which is actually a nil check.
250+
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
251+
c.builder.SetInsertPointAtEnd(nilBlock)
252+
c.createRuntimeCall("nilPanic", nil, "")
253+
c.builder.CreateUnreachable()
254+
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
235255

236-
// If the call produces a value, we need to get it using a PHI
237-
// node.
238-
phiBlocks := make([]llvm.BasicBlock, len(functions))
239-
phiValues := make([]llvm.Value, len(functions))
240-
for i, fn := range functions {
241-
// Insert a switch case.
242-
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
243-
c.builder.SetInsertPointAtEnd(bb)
244-
result := c.builder.CreateCall(fn.funcPtr, callParams, "")
245-
c.builder.CreateBr(nextBlock)
246-
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
247-
phiBlocks[i] = bb
248-
phiValues[i] = result
249-
}
250-
// Create the PHI node so that the call result flows into the
251-
// next block (after the split). This is only necessary when the
252-
// call produced a value.
253-
if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
254-
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
255-
phi := c.builder.CreatePHI(funcCall.Type(), "")
256-
phi.AddIncoming(phiValues, phiBlocks)
257-
funcCall.ReplaceAllUsesWith(phi)
258-
}
256+
// Gather the list of parameters for every call we're going to make.
257+
callParams := make([]llvm.Value, call.OperandsCount()-1)
258+
for i := range callParams {
259+
callParams[i] = call.Operand(i)
260+
}
259261

260-
// Finally, remove the old instructions.
261-
funcCall.EraseFromParentAsInstruction()
262-
for _, inttoptr := range getUses(getFuncPtrCall) {
263-
inttoptr.EraseFromParentAsInstruction()
264-
}
265-
getFuncPtrCall.EraseFromParentAsInstruction()
266-
}
267-
}
262+
// If the call produces a value, we need to get it using a PHI
263+
// node.
264+
phiBlocks := make([]llvm.BasicBlock, len(functions))
265+
phiValues := make([]llvm.Value, len(functions))
266+
for i, fn := range functions {
267+
// Insert a switch case.
268+
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
269+
c.builder.SetInsertPointAtEnd(bb)
270+
result := createCall(fn.funcPtr, callParams)
271+
c.builder.CreateBr(nextBlock)
272+
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
273+
phiBlocks[i] = bb
274+
phiValues[i] = result
275+
}
276+
// Create the PHI node so that the call result flows into the
277+
// next block (after the split). This is only necessary when the
278+
// call produced a value.
279+
if call.Type().TypeKind() != llvm.VoidTypeKind {
280+
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
281+
phi := c.builder.CreatePHI(call.Type(), "")
282+
phi.AddIncoming(phiValues, phiBlocks)
283+
call.ReplaceAllUsesWith(phi)
268284
}
269285
}

compiler/func.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,15 @@ const (
3232
// funcImplementation picks an appropriate func value implementation for the
3333
// target.
3434
func (c *Compiler) funcImplementation() funcValueImplementation {
35-
if c.GOARCH == "wasm" {
35+
// Always pick the switch implementation, as it allows the use of blocking
36+
// inside a function that is used as a func value.
37+
switch c.selectScheduler() {
38+
case "coroutines":
3639
return funcValueSwitch
37-
} else {
40+
case "tasks":
3841
return funcValueDoubleword
42+
default:
43+
panic("unknown scheduler type")
3944
}
4045
}
4146

0 commit comments

Comments
 (0)