@@ -23,6 +23,7 @@ import (
2323 "runtime"
2424 "strings"
2525 "sync"
26+ "sync/atomic"
2627 "syscall"
2728 "time"
2829
@@ -64,100 +65,86 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
6465
6566// commandConn implements net.Conn
6667type commandConn struct {
67- cmd * exec. Cmd
68- cmdExited bool
69- cmdWaitErr error
70- cmdMutex sync. Mutex
71- stdin io.WriteCloser
72- stdout io.ReadCloser
73- stderrMu sync.Mutex
74- stderr bytes.Buffer
75- stdioClosedMu sync. Mutex // for stdinClosed and stdoutClosed
76- stdinClosed bool
77- stdoutClosed bool
78- localAddr net.Addr
79- remoteAddr net.Addr
68+ cmdMutex sync. Mutex // for cmd, cmdWaitErr
69+ cmd * exec. Cmd
70+ cmdWaitErr error
71+ cmdExited atomic. Bool
72+ stdin io.WriteCloser
73+ stdout io.ReadCloser
74+ stderrMu sync.Mutex // for stderr
75+ stderr bytes.Buffer
76+ stdinClosed atomic. Bool
77+ stdoutClosed atomic. Bool
78+ closing atomic. Bool
79+ localAddr net.Addr
80+ remoteAddr net.Addr
8081}
8182
82- // killIfStdioClosed kills the cmd if both stdin and stdout are closed.
83- func (c * commandConn ) killIfStdioClosed () error {
84- c .stdioClosedMu .Lock ()
85- stdioClosed := c .stdoutClosed && c .stdinClosed
86- c .stdioClosedMu .Unlock ()
87- if ! stdioClosed {
88- return nil
83+ // kill terminates the process. On Windows it kills the process directly,
84+ // whereas on other platforms, a SIGTERM is sent, before forcefully terminating
85+ // the process after 3 seconds.
86+ func (c * commandConn ) kill () {
87+ if c .cmdExited .Load () {
88+ return
8989 }
90- return c .kill ()
91- }
92-
93- // killAndWait tries sending SIGTERM to the process before sending SIGKILL.
94- func killAndWait (cmd * exec.Cmd ) error {
90+ c .cmdMutex .Lock ()
9591 var werr error
9692 if runtime .GOOS != "windows" {
9793 werrCh := make (chan error )
98- go func () { werrCh <- cmd .Wait () }()
99- cmd .Process .Signal (syscall .SIGTERM )
94+ go func () { werrCh <- c . cmd .Wait () }()
95+ _ = c . cmd .Process .Signal (syscall .SIGTERM )
10096 select {
10197 case werr = <- werrCh :
10298 case <- time .After (3 * time .Second ):
103- cmd .Process .Kill ()
99+ _ = c . cmd .Process .Kill ()
104100 werr = <- werrCh
105101 }
106102 } else {
107- cmd .Process .Kill ()
108- werr = cmd .Wait ()
103+ _ = c . cmd .Process .Kill ()
104+ werr = c . cmd .Wait ()
109105 }
110- return werr
106+ c .cmdWaitErr = werr
107+ c .cmdMutex .Unlock ()
108+ c .cmdExited .Store (true )
111109}
112110
113- // kill returns nil if the command terminated, regardless to the exit status.
114- func (c * commandConn ) kill () error {
115- var werr error
116- c .cmdMutex .Lock ()
117- if c .cmdExited {
118- werr = c .cmdWaitErr
119- } else {
120- werr = killAndWait (c .cmd )
121- c .cmdWaitErr = werr
122- c .cmdExited = true
123- }
124- c .cmdMutex .Unlock ()
125- if werr == nil {
126- return nil
127- }
128- wExitErr , ok := werr .(* exec.ExitError )
129- if ok {
130- if wExitErr .ProcessState .Exited () {
131- return nil
132- }
111+ // handleEOF handles io.EOF errors while reading or writing from the underlying
112+ // command pipes.
113+ //
114+ // When we've received an EOF we expect that the command will
115+ // be terminated soon. As such, we call Wait() on the command
116+ // and return EOF or the error depending on whether the command
117+ // exited with an error.
118+ //
119+ // If Wait() does not return within 10s, an error is returned
120+ func (c * commandConn ) handleEOF (err error ) error {
121+ if err != io .EOF {
122+ return err
133123 }
134- return errors .Wrapf (werr , "commandconn: failed to wait" )
135- }
136124
137- func (c * commandConn ) onEOF (eof error ) error {
138- // when we got EOF, the command is going to be terminated
139- var werr error
140125 c .cmdMutex .Lock ()
141- if c .cmdExited {
126+ defer c .cmdMutex .Unlock ()
127+
128+ var werr error
129+ if c .cmdExited .Load () {
142130 werr = c .cmdWaitErr
143131 } else {
144132 werrCh := make (chan error )
145133 go func () { werrCh <- c .cmd .Wait () }()
146134 select {
147135 case werr = <- werrCh :
148136 c .cmdWaitErr = werr
149- c .cmdExited = true
137+ c .cmdExited . Store ( true )
150138 case <- time .After (10 * time .Second ):
151- c .cmdMutex .Unlock ()
152139 c .stderrMu .Lock ()
153140 stderr := c .stderr .String ()
154141 c .stderrMu .Unlock ()
155- return errors .Errorf ("command %v did not exit after %v: stderr=%q" , c .cmd .Args , eof , stderr )
142+ return errors .Errorf ("command %v did not exit after %v: stderr=%q" , c .cmd .Args , err , stderr )
156143 }
157144 }
158- c . cmdMutex . Unlock ()
145+
159146 if werr == nil {
160- return eof
147+ return err
161148 }
162149 c .stderrMu .Lock ()
163150 stderr := c .stderr .String ()
@@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error {
166153}
167154
168155func ignorableCloseError (err error ) bool {
169- errS := err .Error ()
170- ss := []string {
171- os .ErrClosed .Error (),
156+ return strings .Contains (err .Error (), os .ErrClosed .Error ())
157+ }
158+
159+ func (c * commandConn ) Read (p []byte ) (int , error ) {
160+ n , err := c .stdout .Read (p )
161+ // check after the call to Read, since
162+ // it is blocking, and while waiting on it
163+ // Close might get called
164+ if c .closing .Load () {
165+ // If we're currently closing the connection
166+ // we don't want to call onEOF, but we do want
167+ // to return an io.EOF
168+ return 0 , io .EOF
172169 }
173- for _ , s := range ss {
174- if strings .Contains (errS , s ) {
175- return true
176- }
170+
171+ return n , c .handleEOF (err )
172+ }
173+
174+ func (c * commandConn ) Write (p []byte ) (int , error ) {
175+ n , err := c .stdin .Write (p )
176+ // check after the call to Write, since
177+ // it is blocking, and while waiting on it
178+ // Close might get called
179+ if c .closing .Load () {
180+ // If we're currently closing the connection
181+ // we don't want to call onEOF, but we do want
182+ // to return an io.EOF
183+ return 0 , io .EOF
177184 }
178- return false
185+
186+ return n , c .handleEOF (err )
179187}
180188
189+ // CloseRead allows commandConn to implement halfCloser
181190func (c * commandConn ) CloseRead () error {
182191 // NOTE: maybe already closed here
183192 if err := c .stdout .Close (); err != nil && ! ignorableCloseError (err ) {
184- logrus . Warnf ( "commandConn.CloseRead: %v" , err )
193+ return err
185194 }
186- c .stdioClosedMu .Lock ()
187- c .stdoutClosed = true
188- c .stdioClosedMu .Unlock ()
189- if err := c .killIfStdioClosed (); err != nil {
190- logrus .Warnf ("commandConn.CloseRead: %v" , err )
191- }
192- return nil
193- }
195+ c .stdoutClosed .Store (true )
194196
195- func (c * commandConn ) Read (p []byte ) (int , error ) {
196- n , err := c .stdout .Read (p )
197- if err == io .EOF {
198- err = c .onEOF (err )
197+ if c .stdinClosed .Load () {
198+ c .kill ()
199199 }
200- return n , err
200+
201+ return nil
201202}
202203
204+ // CloseWrite allows commandConn to implement halfCloser
203205func (c * commandConn ) CloseWrite () error {
204206 // NOTE: maybe already closed here
205207 if err := c .stdin .Close (); err != nil && ! ignorableCloseError (err ) {
206- logrus .Warnf ("commandConn.CloseWrite: %v" , err )
207- }
208- c .stdioClosedMu .Lock ()
209- c .stdinClosed = true
210- c .stdioClosedMu .Unlock ()
211- if err := c .killIfStdioClosed (); err != nil {
212- logrus .Warnf ("commandConn.CloseWrite: %v" , err )
208+ return err
213209 }
214- return nil
215- }
210+ c .stdinClosed .Store (true )
216211
217- func (c * commandConn ) Write (p []byte ) (int , error ) {
218- n , err := c .stdin .Write (p )
219- if err == io .EOF {
220- err = c .onEOF (err )
212+ if c .stdoutClosed .Load () {
213+ c .kill ()
221214 }
222- return n , err
215+ return nil
223216}
224217
218+ // Close is the net.Conn func that gets called
219+ // by the transport when a dial is cancelled
220+ // due to it's context timing out. Any blocked
221+ // Read or Write calls will be unblocked and
222+ // return errors. It will block until the underlying
223+ // command has terminated.
225224func (c * commandConn ) Close () error {
226- var err error
227- if err = c .CloseRead (); err != nil {
225+ c .closing .Store (true )
226+ defer c .closing .Store (false )
227+
228+ if err := c .CloseRead (); err != nil {
228229 logrus .Warnf ("commandConn.Close: CloseRead: %v" , err )
230+ return err
229231 }
230- if err = c .CloseWrite (); err != nil {
232+ if err : = c .CloseWrite (); err != nil {
231233 logrus .Warnf ("commandConn.Close: CloseWrite: %v" , err )
234+ return err
232235 }
233- return err
236+
237+ return nil
234238}
235239
236240func (c * commandConn ) LocalAddr () net.Addr {
0 commit comments