Skip to content

transform (coroutines): fix memory corruption for tail calls that reference stack allocations #2117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions transform/coroutines.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
continue
}

if len(fn.normalCalls) == 0 {
// No suspend points. Lower without turning it into a coroutine.
if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think alloca instructions are always at the start of the entry block. I think it would be safer to check the entire entry block for alloca instructions, just in case some are not the first instruction.

(Technically they can be anywhere in the function, but we check in other places that this isn't possible).

Copy link
Member Author

@niaow niaow Sep 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are any alloca instructions that are not at the start of the entry block, the coroutine lowering pass crashes, IIRC.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC it crashes if they are not in the entry block, but I think they are allowed if they are not directly at the start of the entry block.

// No suspend points or stack allocations. Lower without turning it into a coroutine.
c.lowerFuncFast(fn)
} else {
// There are suspend points, so it is necessary to turn this into a coroutine.
// There are suspend points or stack allocations, so it is necessary to turn this into a coroutine.
c.lowerFuncCoro(fn)
}
}
Expand Down Expand Up @@ -763,6 +763,27 @@ func (c *coroutineLoweringPass) lowerCallReturn(caller *asyncFunc, call llvm.Val
// lowerFuncCoro transforms an async function into a coroutine by lowering async operations to `llvm.coro` intrinsics.
// See https://llvm.org/docs/Coroutines.html for more information on these intrinsics.
func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
// Ensure that any alloca instructions in the entry block are at the start.
// Otherwise, block splitting would result in unintended behavior.
{
// Skip alloca instructions at the start of the block.
inst := fn.fn.FirstBasicBlock().FirstInstruction()
for !inst.IsAAllocaInst().IsNil() {
inst = llvm.NextInstruction(inst)
}
Comment on lines +771 to +773
Copy link
Member

@aykevl aykevl Sep 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this won't work correctly if there are no alloca instructions in the entry block: it will eventually reach the last instruction in the block and crash on a nil pointer dereference.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop condition is the other way around. If there are no alloca instructions this will run 0 times.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. I got confused by the double negation.


// Find any other alloca instructions and move them after the other allocas.
c.builder.SetInsertPointBefore(inst)
for !inst.IsNil() {
next := llvm.NextInstruction(inst)
if !inst.IsAAllocaInst().IsNil() {
inst.RemoveFromParentAsInstruction()
c.builder.Insert(inst)
}
inst = next
}
Comment on lines +775 to +784
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This moves alloca instructions together, which I guess is fine, but I don't see how this addresses the issue I raised before?

I don't think alloca instructions are always at the start of the entry block. I think it would be safer to check the entire entry block for alloca instructions, just in case some are not the first instruction.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This moves all alloca instructions to the start of the entry block, so later the check does correctly apply to all allocas in the entry block. This also fixes the potential issues where an alloca could be moved by a later transformation into another block by SplitBasicBlock.

}

returnType := fn.fn.Type().ElementType().ReturnType()

// Prepare coroutine state.
Expand Down Expand Up @@ -827,6 +848,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
}

// Lower returns.
var postTail llvm.BasicBlock
for _, ret := range fn.returns {
// Get terminator instruction.
terminator := ret.block.LastInstruction()
Expand Down Expand Up @@ -886,10 +908,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
call.EraseFromParentAsInstruction()
}

// Replace terminator with branch to cleanup.
// Replace terminator with a branch to the exit.
var exit llvm.BasicBlock
if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

// Exit through the cleanup path.
exit = cleanup
} else {
if postTail.IsNil() {
// Create a path with a suspend that never reawakens.
postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail")
c.builder.SetInsertPointAtEnd(postTail)
// %coro.save = call token @llvm.coro.save(i8* %coro.state)
save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
// %call.suspend = llvm.coro.suspend(token %coro.save, i1 false)
// switch i8 %call.suspend, label %suspend [i8 0, label %wakeup
// i8 1, label %cleanup]
suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable")
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)
c.builder.SetInsertPointAtEnd(unreachableBlock)
c.builder.CreateUnreachable()
}

// Exit through a permanent suspend.
exit = postTail
}

terminator.EraseFromParentAsInstruction()
c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup)
c.builder.CreateBr(exit)
}

// Lower regular calls.
Expand Down
34 changes: 33 additions & 1 deletion transform/testdata/coroutines.ll
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,43 @@ entry:
}

; Normal function which should not be transformed.
define void @doNothing(i8*, i8*) {
define void @doNothing(i8*, i8* %parentHandle) {
entry:
; No suspend points and no allocas: the lowering pass must leave this body as-is.
ret void
}

; Regression test: ensure that a tail call does not destroy the frame while it is still in use.
; Previously, the tail-call lowering transform would branch to the cleanup block after usePtr.
; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack.
; After a suspend %a would be used, resulting in memory corruption.
define i8 @coroutineTailRegression(i8*, i8* %parentHandle) {
entry:
; %a lives on this function's stack frame.
%a = alloca i8
store i8 5, i8* %a
; Tail call passes %a to @usePtr, which suspends before reading it,
; so %a must stay valid across that suspend (the regression under test).
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}

; Regression test: ensure that stack allocations alive during a suspend end up on the heap.
; This used to not be transformed to a coroutine, keeping %a on the stack.
; After a suspend %a would be used, resulting in memory corruption.
define i8 @allocaTailRegression(i8*, i8* %parentHandle) {
entry:
; Stack allocation that is alive across the @sleep suspend point below.
%a = alloca i8
call void @sleep(i64 1000000, i8* undef, i8* null)
store i8 5, i8* %a
; %a is also passed to @usePtr, which itself suspends before reading it.
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}

; usePtr uses a pointer after a suspend.
define i8 @usePtr(i8*, i8*, i8* %parentHandle) {
entry:
; Suspend first, then dereference the caller-supplied pointer %0.
call void @sleep(i64 1000000, i8* undef, i8* null)
%val = load i8, i8* %0
ret i8 %val
}

; Goroutine that sleeps and does nothing.
; Should be a void tail call.
define void @sleepGoroutine(i8*, i8* %parentHandle) {
Expand Down
128 changes: 122 additions & 6 deletions transform/testdata/coroutines.out.ll
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
call void @sleep(i64 %1, i8* undef, i8* %parentHandle)
ret i32 undef
}
Expand Down Expand Up @@ -84,7 +84,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
%ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef)
%4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle)
Expand All @@ -93,7 +93,7 @@ entry:

define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) {
entry:
%call.return = alloca i32
%call.return = alloca i32, align 4
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
Expand All @@ -116,10 +116,10 @@ entry:
]

wakeup: ; preds = %entry
%4 = load i32, i32* %call.return
%4 = load i32, i32* %call.return, align 4
call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast)
%5 = icmp eq i32 %4, 0
store i1 %5, i1* %task.retPtr.bitcast
store i1 %5, i1* %task.retPtr.bitcast, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup

Expand All @@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup
br label %suspend
}

define void @doNothing(i8* %0, i8* %1) {
; Expected output: no suspend points and no allocas, so the pass leaves the
; body untransformed (no coroutine intrinsics inserted).
define void @doNothing(i8* %0, i8* %parentHandle) {
entry:
ret void
}

define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) {
entry:
; The alloca is kept in the coroutine frame; because the function is now a
; real coroutine, LLVM's coro lowering can spill %a to the heap frame.
%a = alloca i8, align 1
; Standard coroutine prologue: identify the coroutine, allocate its frame,
; and publish the coroutine handle into the task state.
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
store i8 5, i8* %a, align 1
; Restore the parent's state and return pointer before the tail call, so the
; callee returns directly to this function's caller.
%coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
; Branch to a permanent suspend instead of cleanup, keeping %a alive while
; the callee still uses it (the fix under test).
br label %post.tail

suspend: ; preds = %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Frees the coroutine frame; reached only when the coroutine is destroyed.
cleanup: ; preds = %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend

; Suspend that is never resumed: the resume case (i8 0) goes to unreachable.
post.tail: ; preds = %entry
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]

unreachable: ; preds = %post.tail
unreachable
}

define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) {
entry:
; Alloca alive across a suspend; as part of a coroutine it can be spilled
; into the heap-allocated frame by LLVM's coro lowering.
%a = alloca i8, align 1
; Coroutine prologue: frame allocation plus task-state bookkeeping.
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
; First suspend point: sleep, then park until resumed (i8 0 -> wakeup).
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
%coro.save1 = call token @llvm.coro.save(i8* %coro.state)
%call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false)
switch i8 %call.suspend2, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]

wakeup: ; preds = %entry
store i8 5, i8* %a, align 1
; Hand state and return pointer back to the parent, then tail-call.
%1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
br label %post.tail

suspend: ; preds = %entry, %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Frees the coroutine frame on destruction.
cleanup: ; preds = %entry, %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend

; Permanent suspend after the tail call: keeps the frame (and %a) alive
; while the callee runs; the resume edge is unreachable by construction.
post.tail: ; preds = %wakeup
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]

unreachable: ; preds = %post.tail
unreachable
}

define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) {
entry:
; Coroutine prologue: allocate the frame and register it with the task.
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
; Suspend across the sleep; resume (i8 0) continues at wakeup.
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]

wakeup: ; preds = %entry
; Dereference the caller's pointer only after the suspend, write the result
; through the task's return pointer, and return control to the parent.
%2 = load i8, i8* %0, align 1
store i8 %2, i8* %task.retPtr, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup

suspend: ; preds = %entry, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Frees the coroutine frame before the final coro.end.
cleanup: ; preds = %entry, %wakeup
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend
}

define void @sleepGoroutine(i8* %0, i8* %parentHandle) {
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
Expand Down