Skip to content

Add kernel.Task.BlockFD[WithDeadline]. #8044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nogo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ analyzers:
- "linkname to unknown symbol"
exclude:
- ".*/containerd/sys/subprocess_unsafe_linux.go"
SA1017: # Channels used with os/signal.Notify should be buffered.
internal:
exclude:
- pkg/sentry/kernel/signal.go # Intentional.
SA1019: # Use of deprecated identifier.
# disable for now due to misattribution from golang.org/issue/44195.
generated:
Expand Down
24 changes: 24 additions & 0 deletions pkg/gohacks/gohacks_unsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,30 @@ func StringFromImmutableBytes(bs []byte) string {
// Note that go:linkname silently doesn't work if the local name is exported,
// necessitating an indirection for exported functions.

// EnterSyscall is runtime.entersyscall.
//
// It is an exported wrapper around the unexported, linknamed entersyscall
// declaration, because go:linkname silently doesn't work if the local name
// is exported.
//
// WARNING: It is unsafe to call any functions between runtime.entersyscall
// and runtime.exitsyscall unless they are defined in the runtime package or in
// assembly, for reasons explained by a comment in syscall.Syscall.
//
//go:nosplit
func EnterSyscall() {
entersyscall()
}

// entersyscall is the local, unexported name bound to runtime.entersyscall;
// it has no Go body in this file.
//
//go:linkname entersyscall runtime.entersyscall
func entersyscall()

// ExitSyscall is runtime.exitsyscall.
//
// It is an exported wrapper around the unexported, linknamed exitsyscall
// declaration, because go:linkname silently doesn't work if the local name
// is exported.
//
// NOTE(review): presumably this must only be called after a matching
// EnterSyscall on the same goroutine; confirm against the Go runtime's
// requirements for entersyscall/exitsyscall.
//
//go:nosplit
func ExitSyscall() {
exitsyscall()
}

// exitsyscall is the local, unexported name bound to runtime.exitsyscall; it
// has no Go body in this file.
//
//go:linkname exitsyscall runtime.exitsyscall
func exitsyscall()

// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
//
//go:nosplit
Expand Down
18 changes: 18 additions & 0 deletions pkg/sentry/kernel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,19 @@ go_template_instance(
},
)

go_template_instance(
name = "tidset_atomicptrmap",
out = "tidset_atomicptrmap_unsafe.go",
package = "kernel",
prefix = "tidset",
template = "//pkg/sync/atomicptrmap:generic_atomicptrmap",
types = {
"Key": "int32",
"Value": "tidsetValue",
"Hasher": "tidsetHasher",
},
)

go_template_instance(
name = "fd_table_refs",
out = "fd_table_refs.go",
Expand Down Expand Up @@ -234,13 +247,15 @@ go_library(
"signal.go",
"signal_handlers.go",
"signal_handlers_mutex.go",
"signal_unsafe.go",
"socket_list.go",
"syscalls.go",
"syscalls_state.go",
"syslog.go",
"task.go",
"task_acct.go",
"task_block.go",
"task_block_unsafe.go",
"task_cgroup.go",
"task_clone.go",
"task_context.go",
Expand All @@ -266,6 +281,7 @@ go_library(
"thread_group_timer_mutex.go",
"threads.go",
"threads_impl.go",
"tidset_atomicptrmap_unsafe.go",
"timekeeper.go",
"timekeeper_state.go",
"tty.go",
Expand Down Expand Up @@ -298,12 +314,14 @@ go_library(
"//pkg/errors/linuxerr",
"//pkg/eventchannel",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/goid",
"//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/procid",
"//pkg/refs",
"//pkg/refsvfs2",
"//pkg/safemem",
Expand Down
78 changes: 78 additions & 0 deletions pkg/sentry/kernel/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ package kernel

import (
"fmt"
"os"
"os/signal"

"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
)

// SignalPanic is used to panic the running threads. It is a signal which
Expand All @@ -43,6 +47,9 @@ func (k *Kernel) sendExternalSignal(info *linux.SignalInfo, context string) {
case platform.SignalInterrupt:
// Assume that a call to platform.Context.Interrupt() misfired.

case SignalInterruptSyscall:
// Expected.

case SignalPanic:
// SignalPanic is also specially handled in sentry setup to ensure that
// it causes a panic even after tasks exit, but SignalPanic may also
Expand Down Expand Up @@ -76,3 +83,74 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *linux.SignalInf
info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}

// SignalInterruptSyscall is sent by Task.interrupt() to task goroutine threads
// in host syscalls that have atomically unmasked SignalInterruptSyscall.
// Threads whose IDs may be stored in Task.syscallTID have
// SignalInterruptSyscall masked when not in said syscalls, ensuring that
// signal delivery does not occur unless it would interrupt a syscall.
//
// interruptibleSyscallSignalMask registers this signal with os/signal so that
// the Go runtime installs a handler for it, and Kernel.sendExternalSignal
// treats its arrival as expected.
const SignalInterruptSyscall = linux.SIGPWR

var (
// interruptSyscallReadyTIDs stores thread IDs for which
// interruptibleSyscallSignalMask has already been called. (This is feasible
// because the Go runtime never destroys threads, so thread IDs are never
// reused.)
interruptSyscallReadyTIDs tidsetAtomicPtrMap

// interruptSyscallInitOnce guards the one-time setup performed inside
// interruptibleSyscallSignalMask: computing interruptSyscallSigmask and
// registering SignalInterruptSyscall with os/signal.
interruptSyscallInitOnce sync.Once
// interruptSyscallSigmask is the mask returned by
// interruptibleSyscallSignalMask. It is written only under
// interruptSyscallInitOnce.
interruptSyscallSigmask linux.SignalSet // SignalInterruptSyscall is unmasked
)

// interruptibleSyscallSignalMask prepares the calling thread for interruptible
// host syscalls and returns the signal mask such a syscall should run with.
//
// The first time it is called for a given tid, it blocks
// SignalInterruptSyscall on the calling thread and records tid in
// interruptSyscallReadyTIDs; later calls for the same tid skip straight to
// returning the mask. The returned mask is the thread signal mask captured
// during one-time initialization, with SignalInterruptSyscall removed.
//
// It panics if sigprocmask fails.
//
// Preconditions:
//   - runtime.LockOSThread() is in effect.
//   - tid is the caller's thread ID.
func interruptibleSyscallSignalMask(tid int32) linux.SignalSet {
	// Slow path: this thread has not been prepared yet.
	if interruptSyscallReadyTIDs.Load(tid) == nil {
		interruptSyscallInitOnce.Do(func() {
			// Read the current thread's signal mask, on the assumption that
			// it is the right mask for every thread that can run a task
			// goroutine.
			var current linux.SignalSet
			err := sigprocmask(0, nil, &current)
			if err != nil {
				panic(fmt.Sprintf("sigprocmask(0, nil, %p) failed: %v", &current, err))
			}
			interruptSyscallSigmask = current &^ linux.SignalSetOf(SignalInterruptSyscall)

			// A userspace handler must exist for SignalInterruptSyscall,
			// otherwise ppoll(2) would be restarted automatically. Asking Go
			// to install one via signal.Notify() is the simplest way to get
			// it. Notify() only does non-blocking sends and nobody cares
			// about the notification itself, so the channel is deliberately
			// unbuffered and never received from.
			discard := make(chan os.Signal)
			signal.Notify(discard, unix.Signal(SignalInterruptSyscall))
		})

		// Keep the signal blocked on this thread outside of interruptible
		// syscalls.
		blocked := linux.SignalSetOf(SignalInterruptSyscall)
		if err := sigprocmask(linux.SIG_BLOCK, &blocked, nil); err != nil {
			panic(fmt.Sprintf("sigprocmask(SIG_BLOCK, %p, nil) failed: %v", &blocked, err))
		}
		interruptSyscallReadyTIDs.Store(tid, &tidsetValue{})
	}
	return interruptSyscallSigmask
}

// tidsetValue is the empty value type stored in the tidset map; only the
// presence of a key matters.
type tidsetValue struct{}

// tidsetHasher hashes int32 thread IDs for the tidset map. It carries no
// state.
type tidsetHasher struct{}

// Init implements generic_atomicptrmap.Hasher.Init.
//
// It is a no-op because tidsetHasher is stateless.
func (tidsetHasher) Init() {}

// Hash implements generic_atomicptrmap.Hasher.Hash.
//
// The hash is one step of the linear congruential generator that POSIX
// defines for nrand48(), minus the additive constant: with overwhelming
// probability the addition would not survive the right shift anyway.
func (tidsetHasher) Hash(tid int32) uintptr {
	const lcgMultiplier uint64 = 0x5deece66d
	scrambled := uint64(tid) * lcgMultiplier
	return uintptr(scrambled >> 16)
}
30 changes: 30 additions & 0 deletions pkg/sentry/kernel/signal_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package kernel

import (
"unsafe"

"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
)

// sigprocmask invokes the rt_sigprocmask system call directly via
// unix.RawSyscall6, passing how, set and oldset through unchanged along with
// linux.SignalSetSize as the signal set size. Either pointer may be nil.
//
// It returns the syscall's errno as an error, or nil if errno is zero.
func sigprocmask(how int32, set, oldset *linux.SignalSet) error {
	// The unsafe.Pointer-to-uintptr conversions must stay inside the call
	// expression itself.
	if _, _, errno := unix.RawSyscall6(
		unix.SYS_RT_SIGPROCMASK,
		uintptr(how),
		uintptr(unsafe.Pointer(set)),
		uintptr(unsafe.Pointer(oldset)),
		linux.SignalSetSize,
		0,
		0,
	); errno != 0 {
		return errno
	}
	return nil
}
7 changes: 7 additions & 0 deletions pkg/sentry/kernel/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ type Task struct {
// interruptChan is always notified after restore (see Task.run).
interruptChan chan struct{} `state:"nosave"`

// If syscallTID is non-zero, it is the task goroutine's current thread ID,
// and the task goroutine is blocked in ppoll(2) or epoll_pwait(2) with
// SignalInterruptSyscall unmasked.
//
// syscallTID is owned by the task goroutine.
syscallTID atomicbitops.Int32

// gosched contains the current scheduling state of the task goroutine.
//
// gosched is protected by goschedSeq. gosched is owned by the task
Expand Down
71 changes: 68 additions & 3 deletions pkg/sentry/kernel/task_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
package kernel

import (
"os"
"runtime"
"runtime/trace"
"time"

"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)

// globalTGID is the host process ID of this process, captured once at package
// initialization. Task.interrupt passes it as the thread group ID when
// signalling a task goroutine's thread with tgkill.
var globalTGID = int32(os.Getpid())

// BlockWithTimeout blocks t until an event is received from C, the application
// monotonic clock indicates that timeout has elapsed (only if haveTimeout is true),
// or t is interrupted. It returns:
Expand Down Expand Up @@ -199,6 +205,60 @@ func (t *Task) completeSleep() {
t.Activate()
}

// BlockFD waits, with no timeout, for the host file descriptor fd to become
// ready for any of the I/O events in mask, or for t to be interrupted.
//
// On success it returns the events reported as ready for fd. If the
// underlying poll fails, it returns a zero mask and that error.
func (t *Task) BlockFD(fd int32, mask waiter.EventMask) (waiter.EventMask, error) {
	pfds := make([]linux.PollFD, 1)
	pfds[0].FD = fd
	pfds[0].Events = int16(mask.ToLinux())

	// A nil timeout means no deadline; the ready count is not needed.
	if _, err := t.blockPoll(pfds, nil); err != nil {
		return 0, err
	}
	ready := uint32(pfds[0].REvents)
	return waiter.EventMaskFromLinux(ready), nil
}

// BlockFDWithDeadline behaves like BlockFD when haveDeadline is false. When
// haveDeadline is true, the wait is bounded by deadline, measured against the
// kernel's monotonic clock; a deadline that is not in the future yields a zero
// timeout. It is documented to return ETIMEDOUT if the deadline expires before
// fd becomes ready; that error would come from blockPoll, whose
// implementation is not part of this function.
//
// On success it returns the events reported as ready for fd. On failure it
// returns a zero mask and the error from the underlying poll.
func (t *Task) BlockFDWithDeadline(fd int32, mask waiter.EventMask, haveDeadline bool, deadline ktime.Time) (waiter.EventMask, error) {
	if !haveDeadline {
		return t.BlockFD(fd, mask)
	}

	pfds := make([]linux.PollFD, 1)
	pfds[0].FD = fd
	pfds[0].Events = int16(mask.ToLinux())

	// The zero Timespec is used unless the deadline is still in the future.
	var timeout linux.Timespec
	now := t.Kernel().MonotonicClock().Now()
	if now.Before(deadline) {
		timeout = linux.DurationToTimespec(deadline.Sub(now))
	}

	if _, err := t.blockPoll(pfds, &timeout); err != nil {
		return 0, err
	}
	ready := uint32(pfds[0].REvents)
	return waiter.EventMaskFromLinux(ready), nil
}

// blockPoll wraps blockPollUnsafe with the bookkeeping for a blocking
// operation: it calls prepareSleep before and completeSleep after the poll,
// and records the call as a blockRegion trace region. When the race detector
// is enabled it first asserts that the caller is the task goroutine.
//
// pfds and timeout are passed to blockPollUnsafe unchanged, and its results
// are returned unchanged.
func (t *Task) blockPoll(pfds []linux.PollFD, timeout *linux.Timespec) (int, error) {
	if sync.RaceEnabled {
		t.assertTaskGoroutine()
	}

	t.prepareSleep()
	defer t.completeSleep()
	// The region starts now; its End runs before completeSleep on return.
	defer trace.StartRegion(t.traceContext, blockRegion).End()

	n, err := t.blockPollUnsafe(pfds, timeout)
	return n, err
}

// Interrupted implements context.Context.Interrupted.
func (t *Task) Interrupted() bool {
if t.interrupted() {
Expand Down Expand Up @@ -246,6 +306,11 @@ func (t *Task) unsetInterrupted() {
// userspace.
func (t *Task) interrupt() {
	t.interruptSelf()

	// If a thread ID is published in syscallTID, signal that thread with
	// SignalInterruptSyscall.
	tid := t.syscallTID.Load()
	if tid != 0 {
		err := unix.Tgkill(int(globalTGID), int(tid), unix.Signal(SignalInterruptSyscall))
		switch err {
		case nil, unix.ESRCH:
			// Success, or ESRCH, which is deliberately not reported.
		default:
			log.Warningf("failed to tgkill blocked task goroutine thread %d: %v", tid, err)
		}
	}

	t.p.Interrupt()
}

Expand All @@ -256,9 +321,9 @@ func (t *Task) interruptSelf() {
case t.interruptChan <- struct{}{}:
default:
}
// platform.Context.Interrupt() is unnecessary since a task goroutine
// calling interruptSelf() cannot also be blocked in
// platform.Context.Switch().
// Checking syscallTID and calling platform.Context.Interrupt() are
// unnecessary since a task goroutine calling interruptSelf() cannot also be
// blocked in a host syscall or platform.Context.Switch().
}

// Interrupt implements context.Blocker.Interrupt.
Expand Down
Loading