Skip to content

Commit 520e360

Browse files
laurazard, thaJeztah
authored and committed
commandconn: don't return error if command closed successfully
--- commandconn: fix race on `Close()`
During normal operation, if a `Read()` or `Write()` call results in an EOF, we call `onEOF()` to handle the terminating command, and store its exit value. However, if a Read/Write call was blocked while `Close()` is called, the in/out pipes are immediately closed, which causes an EOF to be returned. Here, we shouldn't call `onEOF()`, since the reason why we got an EOF is because we're already terminating the connection. This also prevents a race between two calls to the command's `Wait()`, in the `Close()` call and `onEOF()`
--- Add CLI init timeout to SSH connections
--- connhelper: add 30s ssh default dialer timeout (same as non-ssh dialer)
Signed-off-by: Laura Brehm <[email protected]>
(cherry picked from commit a5ebe22)
Signed-off-by: Sebastiaan van Stijn <[email protected]>
1 parent fad718c commit 520e360

File tree

5 files changed

+316
-110
lines changed

5 files changed

+316
-110
lines changed

cli/command/cli.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"path/filepath"
99
"runtime"
1010
"strconv"
11-
"strings"
1211
"sync"
1312
"time"
1413

@@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration {
327326

328327
func (cli *DockerCli) initializeFromClient() {
329328
ctx := context.Background()
330-
if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") {
331-
// @FIXME context.WithTimeout doesn't work with connhelper / ssh connections
332-
// time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed"
333-
var cancel func()
334-
ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout())
335-
defer cancel()
336-
}
329+
ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout())
330+
defer cancel()
337331

338332
ping, err := cli.client.Ping(ctx)
339333
if err != nil {

cli/connhelper/commandconn/commandconn.go

Lines changed: 106 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"runtime"
2424
"strings"
2525
"sync"
26+
"sync/atomic"
2627
"syscall"
2728
"time"
2829

@@ -64,100 +65,86 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
6465

6566
// commandConn implements net.Conn
6667
type commandConn struct {
67-
cmd *exec.Cmd
68-
cmdExited bool
69-
cmdWaitErr error
70-
cmdMutex sync.Mutex
71-
stdin io.WriteCloser
72-
stdout io.ReadCloser
73-
stderrMu sync.Mutex
74-
stderr bytes.Buffer
75-
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
76-
stdinClosed bool
77-
stdoutClosed bool
78-
localAddr net.Addr
79-
remoteAddr net.Addr
68+
cmdMutex sync.Mutex // for cmd, cmdWaitErr
69+
cmd *exec.Cmd
70+
cmdWaitErr error
71+
cmdExited atomic.Bool
72+
stdin io.WriteCloser
73+
stdout io.ReadCloser
74+
stderrMu sync.Mutex // for stderr
75+
stderr bytes.Buffer
76+
stdinClosed atomic.Bool
77+
stdoutClosed atomic.Bool
78+
closing atomic.Bool
79+
localAddr net.Addr
80+
remoteAddr net.Addr
8081
}
8182

82-
// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
83-
func (c *commandConn) killIfStdioClosed() error {
84-
c.stdioClosedMu.Lock()
85-
stdioClosed := c.stdoutClosed && c.stdinClosed
86-
c.stdioClosedMu.Unlock()
87-
if !stdioClosed {
88-
return nil
83+
// kill terminates the process. On Windows it kills the process directly,
84+
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
85+
// the process after 3 seconds.
86+
func (c *commandConn) kill() {
87+
if c.cmdExited.Load() {
88+
return
8989
}
90-
return c.kill()
91-
}
92-
93-
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
94-
func killAndWait(cmd *exec.Cmd) error {
90+
c.cmdMutex.Lock()
9591
var werr error
9692
if runtime.GOOS != "windows" {
9793
werrCh := make(chan error)
98-
go func() { werrCh <- cmd.Wait() }()
99-
cmd.Process.Signal(syscall.SIGTERM)
94+
go func() { werrCh <- c.cmd.Wait() }()
95+
_ = c.cmd.Process.Signal(syscall.SIGTERM)
10096
select {
10197
case werr = <-werrCh:
10298
case <-time.After(3 * time.Second):
103-
cmd.Process.Kill()
99+
_ = c.cmd.Process.Kill()
104100
werr = <-werrCh
105101
}
106102
} else {
107-
cmd.Process.Kill()
108-
werr = cmd.Wait()
103+
_ = c.cmd.Process.Kill()
104+
werr = c.cmd.Wait()
109105
}
110-
return werr
106+
c.cmdWaitErr = werr
107+
c.cmdMutex.Unlock()
108+
c.cmdExited.Store(true)
111109
}
112110

113-
// kill returns nil if the command terminated, regardless to the exit status.
114-
func (c *commandConn) kill() error {
115-
var werr error
116-
c.cmdMutex.Lock()
117-
if c.cmdExited {
118-
werr = c.cmdWaitErr
119-
} else {
120-
werr = killAndWait(c.cmd)
121-
c.cmdWaitErr = werr
122-
c.cmdExited = true
123-
}
124-
c.cmdMutex.Unlock()
125-
if werr == nil {
126-
return nil
127-
}
128-
wExitErr, ok := werr.(*exec.ExitError)
129-
if ok {
130-
if wExitErr.ProcessState.Exited() {
131-
return nil
132-
}
111+
// handleEOF handles io.EOF errors while reading or writing from the underlying
112+
// command pipes.
113+
//
114+
// When we've received an EOF we expect that the command will
115+
// be terminated soon. As such, we call Wait() on the command
116+
// and return EOF or the error depending on whether the command
117+
// exited with an error.
118+
//
119+
// If Wait() does not return within 10s, an error is returned
120+
func (c *commandConn) handleEOF(err error) error {
121+
if err != io.EOF {
122+
return err
133123
}
134-
return errors.Wrapf(werr, "commandconn: failed to wait")
135-
}
136124

137-
func (c *commandConn) onEOF(eof error) error {
138-
// when we got EOF, the command is going to be terminated
139-
var werr error
140125
c.cmdMutex.Lock()
141-
if c.cmdExited {
126+
defer c.cmdMutex.Unlock()
127+
128+
var werr error
129+
if c.cmdExited.Load() {
142130
werr = c.cmdWaitErr
143131
} else {
144132
werrCh := make(chan error)
145133
go func() { werrCh <- c.cmd.Wait() }()
146134
select {
147135
case werr = <-werrCh:
148136
c.cmdWaitErr = werr
149-
c.cmdExited = true
137+
c.cmdExited.Store(true)
150138
case <-time.After(10 * time.Second):
151-
c.cmdMutex.Unlock()
152139
c.stderrMu.Lock()
153140
stderr := c.stderr.String()
154141
c.stderrMu.Unlock()
155-
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
142+
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
156143
}
157144
}
158-
c.cmdMutex.Unlock()
145+
159146
if werr == nil {
160-
return eof
147+
return err
161148
}
162149
c.stderrMu.Lock()
163150
stderr := c.stderr.String()
@@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error {
166153
}
167154

168155
func ignorableCloseError(err error) bool {
169-
errS := err.Error()
170-
ss := []string{
171-
os.ErrClosed.Error(),
156+
return strings.Contains(err.Error(), os.ErrClosed.Error())
157+
}
158+
159+
func (c *commandConn) Read(p []byte) (int, error) {
160+
n, err := c.stdout.Read(p)
161+
// check after the call to Read, since
162+
// it is blocking, and while waiting on it
163+
// Close might get called
164+
if c.closing.Load() {
165+
// If we're currently closing the connection
166+
// we don't want to call handleEOF, but we do want
167+
// to return an io.EOF
168+
return 0, io.EOF
172169
}
173-
for _, s := range ss {
174-
if strings.Contains(errS, s) {
175-
return true
176-
}
170+
171+
return n, c.handleEOF(err)
172+
}
173+
174+
func (c *commandConn) Write(p []byte) (int, error) {
175+
n, err := c.stdin.Write(p)
176+
// check after the call to Write, since
177+
// it is blocking, and while waiting on it
178+
// Close might get called
179+
if c.closing.Load() {
180+
// If we're currently closing the connection
181+
// we don't want to call handleEOF, but we do want
182+
// to return an io.EOF
183+
return 0, io.EOF
177184
}
178-
return false
185+
186+
return n, c.handleEOF(err)
179187
}
180188

189+
// CloseRead allows commandConn to implement halfCloser
181190
func (c *commandConn) CloseRead() error {
182191
// NOTE: maybe already closed here
183192
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
184-
logrus.Warnf("commandConn.CloseRead: %v", err)
193+
return err
185194
}
186-
c.stdioClosedMu.Lock()
187-
c.stdoutClosed = true
188-
c.stdioClosedMu.Unlock()
189-
if err := c.killIfStdioClosed(); err != nil {
190-
logrus.Warnf("commandConn.CloseRead: %v", err)
191-
}
192-
return nil
193-
}
195+
c.stdoutClosed.Store(true)
194196

195-
func (c *commandConn) Read(p []byte) (int, error) {
196-
n, err := c.stdout.Read(p)
197-
if err == io.EOF {
198-
err = c.onEOF(err)
197+
if c.stdinClosed.Load() {
198+
c.kill()
199199
}
200-
return n, err
200+
201+
return nil
201202
}
202203

204+
// CloseWrite allows commandConn to implement halfCloser
203205
func (c *commandConn) CloseWrite() error {
204206
// NOTE: maybe already closed here
205207
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
206-
logrus.Warnf("commandConn.CloseWrite: %v", err)
207-
}
208-
c.stdioClosedMu.Lock()
209-
c.stdinClosed = true
210-
c.stdioClosedMu.Unlock()
211-
if err := c.killIfStdioClosed(); err != nil {
212-
logrus.Warnf("commandConn.CloseWrite: %v", err)
208+
return err
213209
}
214-
return nil
215-
}
210+
c.stdinClosed.Store(true)
216211

217-
func (c *commandConn) Write(p []byte) (int, error) {
218-
n, err := c.stdin.Write(p)
219-
if err == io.EOF {
220-
err = c.onEOF(err)
212+
if c.stdoutClosed.Load() {
213+
c.kill()
221214
}
222-
return n, err
215+
return nil
223216
}
224217

218+
// Close is the net.Conn func that gets called
219+
// by the transport when a dial is cancelled
220+
// due to its context timing out. Any blocked
221+
// Read or Write calls will be unblocked and
222+
// return errors. It will block until the underlying
223+
// command has terminated.
225224
func (c *commandConn) Close() error {
226-
var err error
227-
if err = c.CloseRead(); err != nil {
225+
c.closing.Store(true)
226+
defer c.closing.Store(false)
227+
228+
if err := c.CloseRead(); err != nil {
228229
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
230+
return err
229231
}
230-
if err = c.CloseWrite(); err != nil {
232+
if err := c.CloseWrite(); err != nil {
231233
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
234+
return err
232235
}
233-
return err
236+
237+
return nil
234238
}
235239

236240
func (c *commandConn) LocalAddr() net.Addr {

0 commit comments

Comments
 (0)