Skip to content

Commit 57f1670

Browse files
akosyakovcsweichel
authored andcommitted
[tunnel] close the tunnel if at least one side is dropped
1 parent 515316a commit 57f1670

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

components/local-app/pkg/bastion/bastion.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"github.com/kevinburke/ssh_config"
2727
"github.com/prometheus/common/log"
2828
"golang.org/x/crypto/ssh"
29-
"golang.org/x/sync/errgroup"
3029
"google.golang.org/grpc"
3130
"google.golang.org/protobuf/proto"
3231

@@ -449,7 +448,9 @@ func (b *Bastion) establishTunnel(ctx context.Context, ws *Workspace, logprefix
449448
logrus.WithError(err).WithField("workspace", ws.WorkspaceID).Warn(logprefix + ": failed to accept connection")
450449
continue
451450
}
451+
logrus.WithField("workspace", ws.WorkspaceID).Debug(logprefix + ": accepted new connection")
452452
go func() {
453+
defer logrus.WithField("workspace", ws.WorkspaceID).Debug(logprefix + ": connection closed")
453454
defer conn.Close()
454455

455456
clientCh := make(chan *TunnelClient, 1)
@@ -476,16 +477,17 @@ func (b *Bastion) establishTunnel(ctx context.Context, ws *Workspace, logprefix
476477
}
477478
defer sshChan.Close()
478479
go ssh.DiscardRequests(reqs)
479-
eg, _ := errgroup.WithContext(listenerCtx)
480-
eg.Go(func() error {
480+
481+
ctx, cancel := context.WithCancel(listenerCtx)
482+
go func() {
481483
_, _ = io.Copy(sshChan, conn)
482-
return nil
483-
})
484-
eg.Go(func() error {
484+
cancel()
485+
}()
486+
go func() {
485487
_, _ = io.Copy(conn, sshChan)
486-
return nil
487-
})
488-
eg.Wait()
488+
cancel()
489+
}()
490+
<-ctx.Done()
489491
}()
490492
}
491493
}()

components/supervisor/pkg/supervisor/supervisor.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,8 @@ func tunnelOverSSH(ctx context.Context, tunneled *ports.TunneledPortsService, ne
731731
newCh.Reject(ssh.Prohibited, err.Error())
732732
return
733733
}
734+
log.Debug("tunnel: accepted new connection")
735+
defer log.Debug("tunnel: connection closed")
734736
defer tunnel.Close()
735737

736738
sshChan, reqs, err := newCh.Accept()
@@ -740,17 +742,16 @@ func tunnelOverSSH(ctx context.Context, tunneled *ports.TunneledPortsService, ne
740742
}
741743
defer sshChan.Close()
742744
go ssh.DiscardRequests(reqs)
743-
var wg sync.WaitGroup
744-
wg.Add(2)
745+
ctx, cancel := context.WithCancel(ctx)
745746
go func() {
746747
_, _ = io.Copy(sshChan, tunnel)
747-
wg.Done()
748+
cancel()
748749
}()
749750
go func() {
750751
_, _ = io.Copy(tunnel, sshChan)
751-
wg.Done()
752+
cancel()
752753
}()
753-
wg.Wait()
754+
<-ctx.Done()
754755
}
755756

756757
func startSSHServer(ctx context.Context, cfg *Config, wg *sync.WaitGroup) {

0 commit comments

Comments
 (0)