Skip to content
This repository was archived by the owner on Jan 17, 2021. It is now read-only.

Add SSH master connection feature #116

Merged
merged 8 commits into from
Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ var _ interface {
} = new(rootCmd)

type rootCmd struct {
skipSync bool
syncBack bool
printVersion bool
bindAddr string
sshFlags string
skipSync bool
syncBack bool
printVersion bool
noReuseConnection bool
bindAddr string
sshFlags string
}

func (c *rootCmd) Spec() cli.CommandSpec {
Expand All @@ -53,6 +54,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) {
fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host")
fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination")
fl.BoolVar(&c.printVersion, "version", false, "print version information and exit")
fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket")
fl.StringVar(&c.bindAddr, "bind", "", "local bind address for ssh tunnel")
fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags")
}
Expand All @@ -76,10 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) {
}

err := sshCode(host, dir, options{
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
noReuseConnection: c.noReuseConnection,
})

if err != nil {
Expand All @@ -101,7 +104,7 @@ Environment variables:
More info: https://github.com/cdr/sshcode

Arguments:
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vDIR is optional.`,
helpTab, vsCodeConfigDirEnv,
helpTab, vsCodeExtensionsDirEnv,
Expand Down
108 changes: 91 additions & 17 deletions sshcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ import (
)

const codeServerPath = "~/.cache/sshcode/sshcode-server"
const sshControlPath = "~/.ssh/control-%h-%p-%r"

type options struct {
skipSync bool
syncBack bool
noOpen bool
bindAddr string
remotePort string
sshFlags string
skipSync bool
syncBack bool
noOpen bool
noReuseConnection bool
bindAddr string
remotePort string
sshFlags string
}

func sshCode(host, dir string, o options) error {
Expand All @@ -53,6 +55,41 @@ func sshCode(host, dir string, o options) error {
return xerrors.Errorf("failed to find available remote port: %w", err)
}

// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
// only happens on the initial connection.
var sshMasterCmd *exec.Cmd
if !o.noReuseConnection {
newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, o.sshFlags, sshControlPath)

// -MN means "start a master socket and don't open a session, just connect".
sshMasterCmdStr := fmt.Sprintf(`ssh %v -MN %v`, newSSHFlags, host)
sshMasterCmd = exec.Command("sh", "-c", sshMasterCmdStr)
sshMasterCmd.Stdin = os.Stdin
sshMasterCmd.Stdout = os.Stdout
sshMasterCmd.Stderr = os.Stderr
err = sshMasterCmd.Start()
if err != nil {
flog.Error("failed to start SSH master connection, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
} else {
// Wait for master to be ready.
err = checkSSHMaster(newSSHFlags, host)
if err != nil {
flog.Error("SSH master failed to start in time, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
if sshMasterCmd.Process != nil {
err = sshMasterCmd.Process.Kill()
if err != nil {
flog.Error("failed to kill SSH master connection, ignoring: %v", err)
}
}
} else {
sshMasterCmd.Stdin = nil
o.sshFlags = newSSHFlags
}
}
}

dlScript := downloadScript(codeServerPath)

// Downloads the latest code-server and allows it to be executed.
Expand Down Expand Up @@ -146,22 +183,39 @@ func sshCode(host, dir string, o options) error {
case <-ctx.Done():
case <-c:
}
flog.Info("exiting")

if !o.syncBack || o.skipSync {
flog.Info("shutting down")
return nil
}
if o.syncBack && !o.skipSync {
flog.Info("synchronizing VS Code back to local")

flog.Info("synchronizing VS Code back to local")
err = syncExtensions(o.sshFlags, host, true)
if err != nil {
flog.Error("failed to sync extensions back: %v", err)
}

err = syncExtensions(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync extensions back: %w", err)
err = syncUserSettings(o.sshFlags, host, true)
if err != nil {
flog.Error("failed to sync user settings settings back: %v", err)
}
}

err = syncUserSettings(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync user settings settings back: %w", err)
// Kill the master connection if we made one.
if !o.noReuseConnection {
// Try using the -O exit syntax first before killing the master.
sshCmdStr = fmt.Sprintf(`ssh %v -O exit %v`, o.sshFlags, host)
sshCmd = exec.Command("sh", "-c", sshCmdStr)
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
err = sshCmd.Run()
if err != nil {
flog.Error("failed to gracefully stop SSH master connection, killing: %v", err)
if sshMasterCmd.Process != nil {
err = sshMasterCmd.Process.Kill()
if err != nil {
flog.Error("failed to kill SSH master connection, ignoring: %v", err)
}
}
}
}

return nil
Expand Down Expand Up @@ -263,6 +317,26 @@ func randomPort() (string, error) {
return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// checkSSHMaster polls every second for 30 seconds to check if the SSH master
// is ready.
func checkSSHMaster(sshFlags string, host string) (err error) {
maxTries := 30
check := func() error {
sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)
sshCmd := exec.Command("sh", "-c", sshCmdStr)
return sshCmd.Run()
}

for i := 0; i < maxTries; i++ {
err = check()
if err == nil {
return nil
}
time.Sleep(time.Second)
}
return err
}

func syncUserSettings(sshFlags string, host string, back bool) error {
localConfDir, err := configDir()
if err != nil {
Expand Down