diff --git a/compiler/channel.go b/compiler/channel.go index 8098ea99a5..144297e84c 100644 --- a/compiler/channel.go +++ b/compiler/channel.go @@ -4,6 +4,7 @@ package compiler // or pseudo-operations that are lowered during goroutine lowering. import ( + "fmt" "go/types" "golang.org/x/tools/go/ssa" @@ -12,11 +13,22 @@ import ( // emitMakeChan returns a new channel value for the given channel type. func (c *Compiler) emitMakeChan(expr *ssa.MakeChan) (llvm.Value, error) { - chanType := c.getLLVMType(c.getRuntimeType("channel")) - size := c.targetData.TypeAllocSize(chanType) + chanType := c.getLLVMType(expr.Type()) + size := c.targetData.TypeAllocSize(chanType.ElementType()) sizeValue := llvm.ConstInt(c.uintptrType, size, false) ptr := c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "chan.alloc") - ptr = c.builder.CreateBitCast(ptr, llvm.PointerType(chanType, 0), "chan") + ptr = c.builder.CreateBitCast(ptr, chanType, "chan") + // Set the elementSize field + elementSizePtr := c.builder.CreateGEP(ptr, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }, "") + elementSize := c.targetData.TypeAllocSize(c.getLLVMType(expr.Type().(*types.Chan).Elem())) + if elementSize > 0xffff { + return ptr, c.makeError(expr.Pos(), fmt.Sprintf("element size is %d bytes, which is bigger than the maximum of %d bytes", elementSize, 0xffff)) + } + elementSizeValue := llvm.ConstInt(c.ctx.Int16Type(), elementSize, false) + c.builder.CreateStore(elementSizeValue, elementSizePtr) return ptr, nil } @@ -33,8 +45,7 @@ func (c *Compiler) emitChanSend(frame *Frame, instr *ssa.Send) { // Do the send. coroutine := c.createRuntimeCall("getCoroutine", nil, "") - valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(chanValue.Type()), false) - c.createRuntimeCall("chanSend", []llvm.Value{coroutine, ch, valueAllocaCast, valueSize}, "") + c.createRuntimeCall("chanSend", []llvm.Value{coroutine, ch, valueAllocaCast}, "") // End the lifetime of the alloca. // This also works around a bug in CoroSplit, at least in LLVM 8: @@ -53,8 +64,7 @@ func (c *Compiler) emitChanRecv(frame *Frame, unop *ssa.UnOp) llvm.Value { // Do the receive. coroutine := c.createRuntimeCall("getCoroutine", nil, "") - valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(valueType), false) - c.createRuntimeCall("chanRecv", []llvm.Value{coroutine, ch, valueAllocaCast, valueSize}, "") + c.createRuntimeCall("chanRecv", []llvm.Value{coroutine, ch, valueAllocaCast}, "") received := c.builder.CreateLoad(valueAlloca, "chan.received") c.emitLifetimeEnd(valueAllocaCast, valueAllocaSize) @@ -72,8 +82,157 @@ func (c *Compiler) emitChanRecv(frame *Frame, unop *ssa.UnOp) llvm.Value { // emitChanClose closes the given channel. func (c *Compiler) emitChanClose(frame *Frame, param ssa.Value) { - valueType := c.getLLVMType(param.Type().(*types.Chan).Elem()) - valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(valueType), false) ch := c.getValue(frame, param) - c.createRuntimeCall("chanClose", []llvm.Value{ch, valueSize}, "") + c.createRuntimeCall("chanClose", []llvm.Value{ch}, "") +} + +// emitSelect emits all IR necessary for a select statements. That's a +// non-trivial amount of code because select is very complex to implement. +func (c *Compiler) emitSelect(frame *Frame, expr *ssa.Select) llvm.Value { + if len(expr.States) == 0 { + // Shortcuts for some simple selects. + llvmType := c.getLLVMType(expr.Type()) + if expr.Blocking { + // Blocks forever: + // select {} + c.createRuntimeCall("deadlockStub", nil, "") + return llvm.Undef(llvmType) + } else { + // No-op: + // select { + // default: + // } + retval := llvm.Undef(llvmType) + retval = c.builder.CreateInsertValue(retval, llvm.ConstInt(c.intType, 0xffffffffffffffff, true), 0, "") + return retval // {-1, false} + } + } + + // This code create a (stack-allocated) slice containing all the select + // cases and then calls runtime.chanSelect to perform the actual select + // statement. + // Simple selects (blocking and with just one case) are already transformed + // into regular chan operations during SSA construction so we don't have to + // optimize such small selects. + + // Go through all the cases. Create the selectStates slice and and + // determine the receive buffer size and alignment. + recvbufSize := uint64(0) + recvbufAlign := 0 + hasReceives := false + var selectStates []llvm.Value + chanSelectStateType := c.getLLVMRuntimeType("chanSelectState") + for _, state := range expr.States { + ch := c.getValue(frame, state.Chan) + selectState := c.getZeroValue(chanSelectStateType) + selectState = c.builder.CreateInsertValue(selectState, ch, 0, "") + switch state.Dir { + case types.RecvOnly: + // Make sure the receive buffer is big enough and has the correct alignment. + llvmType := c.getLLVMType(state.Chan.Type().(*types.Chan).Elem()) + if size := c.targetData.TypeAllocSize(llvmType); size > recvbufSize { + recvbufSize = size + } + if align := c.targetData.ABITypeAlignment(llvmType); align > recvbufAlign { + recvbufAlign = align + } + hasReceives = true + case types.SendOnly: + // Store this value in an alloca and put a pointer to this alloca + // in the send state. + sendValue := c.getValue(frame, state.Send) + alloca := c.createEntryBlockAlloca(sendValue.Type(), "select.send.value") + c.builder.CreateStore(sendValue, alloca) + ptr := c.builder.CreateBitCast(alloca, c.i8ptrType, "") + selectState = c.builder.CreateInsertValue(selectState, ptr, 1, "") + default: + panic("unreachable") + } + selectStates = append(selectStates, selectState) + } + + // Create a receive buffer, where the received value will be stored. + recvbuf := llvm.Undef(c.i8ptrType) + if hasReceives { + allocaType := llvm.ArrayType(c.ctx.Int8Type(), int(recvbufSize)) + recvbufAlloca := c.builder.CreateAlloca(allocaType, "select.recvbuf.alloca") + recvbufAlloca.SetAlignment(recvbufAlign) + recvbuf = c.builder.CreateGEP(recvbufAlloca, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }, "select.recvbuf") + } + + // Create the states slice (allocated on the stack). + statesAllocaType := llvm.ArrayType(chanSelectStateType, len(selectStates)) + statesAlloca := c.builder.CreateAlloca(statesAllocaType, "select.states.alloca") + for i, state := range selectStates { + // Set each slice element to the appropriate channel. + gep := c.builder.CreateGEP(statesAlloca, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false), + }, "") + c.builder.CreateStore(state, gep) + } + statesPtr := c.builder.CreateGEP(statesAlloca, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }, "select.states") + statesLen := llvm.ConstInt(c.uintptrType, uint64(len(selectStates)), false) + + // Convert the 'blocking' flag on this select into a LLVM value. + blockingInt := uint64(0) + if expr.Blocking { + blockingInt = 1 + } + blockingValue := llvm.ConstInt(c.ctx.Int1Type(), blockingInt, false) + + // Do the select in the runtime. + results := c.createRuntimeCall("chanSelect", []llvm.Value{ + recvbuf, + statesPtr, statesLen, statesLen, // []chanSelectState + blockingValue, + }, "") + + // The result value does not include all the possible received values, + // because we can't load them in advance. Instead, the *ssa.Extract + // instruction will treat a *ssa.Select specially and load it there inline. + // Store the receive alloca in a sidetable until we hit this extract + // instruction. + if frame.selectRecvBuf == nil { + frame.selectRecvBuf = make(map[*ssa.Select]llvm.Value) + } + frame.selectRecvBuf[expr] = recvbuf + + return results +} + +// getChanSelectResult returns the special values from a *ssa.Extract expression +// when extracting a value from a select statement (*ssa.Select). Because +// *ssa.Select cannot load all values in advance, it does this later in the +// *ssa.Extract expression. +func (c *Compiler) getChanSelectResult(frame *Frame, expr *ssa.Extract) llvm.Value { + if expr.Index == 0 { + // index + value := c.getValue(frame, expr.Tuple) + index := c.builder.CreateExtractValue(value, expr.Index, "") + if index.Type().IntTypeWidth() < c.intType.IntTypeWidth() { + index = c.builder.CreateSExt(index, c.intType, "") + } + return index + } else if expr.Index == 1 { + // comma-ok + value := c.getValue(frame, expr.Tuple) + return c.builder.CreateExtractValue(value, expr.Index, "") + } else { + // Select statements are (index, ok, ...) where ... is a number of + // received values, depending on how many receive statements there + // are. They are all combined into one alloca (because only one + // receive can proceed at a time) so we'll get that alloca, bitcast + // it to the correct type, and dereference it. + recvbuf := frame.selectRecvBuf[expr.Tuple.(*ssa.Select)] + typ := llvm.PointerType(c.getLLVMType(expr.Type()), 0) + ptr := c.builder.CreateBitCast(recvbuf, typ, "") + return c.builder.CreateLoad(ptr, "") + } } diff --git a/compiler/compiler.go b/compiler/compiler.go index f27b8adee9..dd7e3d5eb9 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -84,6 +84,7 @@ type Frame struct { deferFuncs map[*ir.Function]int deferInvokeFuncs map[string]int deferClosureFuncs map[*ir.Function]int + selectRecvBuf map[*ssa.Select]llvm.Value } type Phi struct { @@ -1445,9 +1446,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { x := c.getValue(frame, expr.X) return c.parseConvert(expr.X.Type(), expr.Type(), x, expr.Pos()) case *ssa.Extract: + if _, ok := expr.Tuple.(*ssa.Select); ok { + return c.getChanSelectResult(frame, expr), nil + } value := c.getValue(frame, expr.Tuple) - result := c.builder.CreateExtractValue(value, expr.Index, "") - return result, nil + return c.builder.CreateExtractValue(value, expr.Index, ""), nil case *ssa.Field: value := c.getValue(frame, expr.X) if s := expr.X.Type().Underlying().(*types.Struct); s.NumFields() > 2 && s.Field(0).Name() == "C union" { @@ -1696,25 +1699,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { c.builder.CreateStore(c.getZeroValue(iteratorType), it) return it, nil case *ssa.Select: - if len(expr.States) == 0 { - // Shortcuts for some simple selects. - llvmType := c.getLLVMType(expr.Type()) - if expr.Blocking { - // Blocks forever: - // select {} - c.createRuntimeCall("deadlockStub", nil, "") - return llvm.Undef(llvmType), nil - } else { - // No-op: - // select { - // default: - // } - retval := llvm.Undef(llvmType) - retval = c.builder.CreateInsertValue(retval, llvm.ConstInt(c.intType, 0xffffffffffffffff, true), 0, "") - return retval, nil // {-1, false} - } - } - return llvm.Undef(c.getLLVMType(expr.Type())), c.makeError(expr.Pos(), "unimplemented: "+expr.String()) + return c.emitSelect(frame, expr), nil case *ssa.Slice: if expr.Max != nil { return llvm.Value{}, c.makeError(expr.Pos(), "todo: full slice expressions (with max): "+expr.Type().String()) diff --git a/src/runtime/chan.go b/src/runtime/chan.go index 60c6905072..578adf2947 100644 --- a/src/runtime/chan.go +++ b/src/runtime/chan.go @@ -28,24 +28,35 @@ import ( ) type channel struct { - state uint8 - blocked *coroutine + elementSize uint16 // the size of one value in this channel + state chanState + blocked *coroutine } +type chanState uint8 + const ( - chanStateEmpty = iota + chanStateEmpty chanState = iota chanStateRecv chanStateSend chanStateClosed ) +// chanSelectState is a single channel operation (send/recv) in a select +// statement. The value pointer is either nil (for receives) or points to the +// value to send (for sends). +type chanSelectState struct { + ch *channel + value unsafe.Pointer +} + func deadlockStub() // chanSend sends a single value over the channel. If this operation can // complete immediately (there is a goroutine waiting for a value), it sends the // value and re-activates both goroutines. If not, it sets itself as waiting on // a value. -func chanSend(sender *coroutine, ch *channel, value unsafe.Pointer, size uintptr) { +func chanSend(sender *coroutine, ch *channel, value unsafe.Pointer) { if ch == nil { // A nil channel blocks forever. Do not scheduler this goroutine again. return @@ -58,7 +69,7 @@ func chanSend(sender *coroutine, ch *channel, value unsafe.Pointer, size uintptr case chanStateRecv: receiver := ch.blocked receiverPromise := receiver.promise() - memcpy(receiverPromise.ptr, value, size) + memcpy(receiverPromise.ptr, value, uintptr(ch.elementSize)) receiverPromise.data = 1 // commaOk = true ch.blocked = receiverPromise.next receiverPromise.next = nil @@ -80,7 +91,7 @@ func chanSend(sender *coroutine, ch *channel, value unsafe.Pointer, size uintptr // sender, it receives the value immediately and re-activates both coroutines. // If not, it sets itself as available for receiving. If the channel is closed, // it immediately activates itself with a zero value as the result. -func chanRecv(receiver *coroutine, ch *channel, value unsafe.Pointer, size uintptr) { +func chanRecv(receiver *coroutine, ch *channel, value unsafe.Pointer) { if ch == nil { // A nil channel blocks forever. Do not scheduler this goroutine again. return @@ -89,7 +100,7 @@ func chanRecv(receiver *coroutine, ch *channel, value unsafe.Pointer, size uintp case chanStateSend: sender := ch.blocked senderPromise := sender.promise() - memcpy(value, senderPromise.ptr, size) + memcpy(value, senderPromise.ptr, uintptr(ch.elementSize)) receiver.promise().data = 1 // commaOk = true ch.blocked = senderPromise.next senderPromise.next = nil @@ -103,7 +114,7 @@ func chanRecv(receiver *coroutine, ch *channel, value unsafe.Pointer, size uintp ch.state = chanStateRecv ch.blocked = receiver case chanStateClosed: - memzero(value, size) + memzero(value, uintptr(ch.elementSize)) receiver.promise().data = 0 // commaOk = false activateTask(receiver) case chanStateRecv: @@ -115,7 +126,7 @@ func chanRecv(receiver *coroutine, ch *channel, value unsafe.Pointer, size uintp // chanClose closes the given channel. If this channel has a receiver or is // empty, it closes the channel. Else, it panics. -func chanClose(ch *channel, size uintptr) { +func chanClose(ch *channel) { if ch == nil { // Not allowed by the language spec. runtimePanic("close of nil channel") @@ -133,7 +144,7 @@ func chanClose(ch *channel, size uintptr) { case chanStateRecv: // The receiver must be re-activated with a zero value. receiverPromise := ch.blocked.promise() - memzero(receiverPromise.ptr, size) + memzero(receiverPromise.ptr, uintptr(ch.elementSize)) receiverPromise.data = 0 // commaOk = false activateTask(ch.blocked) ch.state = chanStateClosed @@ -143,3 +154,63 @@ func chanClose(ch *channel, size uintptr) { ch.state = chanStateClosed } } + +// chanSelect is the runtime implementation of the select statement. This is +// perhaps the most complicated statement in the Go spec. It returns the +// selected index and the 'comma-ok' value. +// +// TODO: do this in a round-robin fashion (as specified in the Go spec) instead +// of picking the first one that can proceed. +func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, blocking bool) (uintptr, bool) { + // See whether we can receive from one of the channels. + for i, state := range states { + if state.ch == nil { + // A nil channel blocks forever, so don't consider it here. + continue + } + if state.value == nil { + // A receive operation. + switch state.ch.state { + case chanStateSend: + // We can receive immediately. + sender := state.ch.blocked + senderPromise := sender.promise() + memcpy(recvbuf, senderPromise.ptr, uintptr(state.ch.elementSize)) + state.ch.blocked = senderPromise.next + senderPromise.next = nil + activateTask(sender) + if state.ch.blocked == nil { + state.ch.state = chanStateEmpty + } + return uintptr(i), true // commaOk = true + case chanStateClosed: + // Receive the zero value. + memzero(recvbuf, uintptr(state.ch.elementSize)) + return uintptr(i), false // commaOk = false + } + } else { + // A send operation: state.value is not nil. + switch state.ch.state { + case chanStateRecv: + receiver := state.ch.blocked + receiverPromise := receiver.promise() + memcpy(receiverPromise.ptr, state.value, uintptr(state.ch.elementSize)) + receiverPromise.data = 1 // commaOk = true + state.ch.blocked = receiverPromise.next + receiverPromise.next = nil + activateTask(receiver) + if state.ch.blocked == nil { + state.ch.state = chanStateEmpty + } + return uintptr(i), false + case chanStateClosed: + runtimePanic("send on closed channel") + } + } + } + + if !blocking { + return ^uintptr(0), false + } + panic("unimplemented: blocking select") +} diff --git a/testdata/channel.go b/testdata/channel.go index cdf3fac6e6..db5b3f2972 100644 --- a/testdata/channel.go +++ b/testdata/channel.go @@ -48,10 +48,63 @@ func main() { } println("sum(100):", sum) - // Test select + // Test simple selects. go selectDeadlock() go selectNoOp() + // Test select with a single send operation (transformed into chan send). + ch = make(chan int) + go fastreceiver(ch) + select { + case ch <- 5: + println("select one sent") + } + close(ch) + + // Test select with a single recv operation (transformed into chan recv). + select { + case n := <-ch: + println("select one n:", n) + } + + // Test select recv with channel that has one entry. + ch = make(chan int) + go func(ch chan int) { + ch <- 55 + }(ch) + time.Sleep(time.Millisecond) + select { + case make(chan int) <- 3: + println("unreachable") + case n := <-ch: + println("select n from chan:", n) + case n := <-make(chan int): + println("unreachable:", n) + } + + // Test select recv with closed channel. + close(ch) + select { + case make(chan int) <- 3: + println("unreachable") + case n := <-ch: + println("select n from closed chan:", n) + case n := <-make(chan int): + println("unreachable:", n) + } + + // Test select send. + ch = make(chan int) + go fastreceiver(ch) + time.Sleep(time.Millisecond) + select { + case ch <- 235: + println("select send") + case n := <-make(chan int): + println("unreachable:", n) + } + close(ch) + // Allow goroutines to exit. time.Sleep(time.Microsecond) } @@ -68,7 +121,7 @@ func sender(ch chan int) { } func sendComplex(ch chan complex128) { - ch <- 7+10.5i + ch <- 7 + 10.5i } func fastsender(ch chan int) { diff --git a/testdata/channel.txt b/testdata/channel.txt index 9daef1be5e..2355bd53cc 100644 --- a/testdata/channel.txt +++ b/testdata/channel.txt @@ -23,3 +23,10 @@ sum(100): 4950 deadlocking select no-op after no-op +select one sent +sum: 5 +select one n: 0 +select n from chan: 55 +select n from closed chan: 0 +select send +sum: 235