diff --git a/src/builder.rs b/src/builder.rs index 5c2b474f4..70bd85050 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,5 +1,6 @@ use super::{Error, Session}; use std::borrow::Cow; +use std::path::{Path, PathBuf}; use std::process::Stdio; use tempfile::Builder; use tokio::io::AsyncReadExt; @@ -10,10 +11,11 @@ use tokio::process; pub struct SessionBuilder { user: Option, port: Option, - keyfile: Option, + keyfile: Option, connect_timeout: Option, server_alive_interval: Option, known_hosts_check: KnownHosts, + control_dir: Option, } impl Default for SessionBuilder { @@ -25,6 +27,7 @@ impl Default for SessionBuilder { connect_timeout: None, server_alive_interval: None, known_hosts_check: KnownHosts::Add, + control_dir: None, } } } @@ -49,7 +52,7 @@ impl SessionBuilder { /// Set the keyfile to use (`ssh -i`). /// /// Defaults to `None`. - pub fn keyfile(&mut self, p: impl AsRef) -> &mut Self { + pub fn keyfile(&mut self, p: impl AsRef) -> &mut Self { self.keyfile = Some(p.as_ref().to_path_buf()); self } @@ -81,6 +84,15 @@ impl SessionBuilder { self } + /// Set the directory in which the temporary directory containing the control socket will + /// be created. + /// + /// If not set, `./` will be used (the current directory). + pub fn control_directory(&mut self, p: impl AsRef) -> &mut Self { + self.control_dir = Some(p.as_ref().to_path_buf()); + self + } + /// Connect to the host at the given `host` over SSH. /// /// The format of `destination` is the same as the `destination` argument to `ssh`. It may be @@ -137,9 +149,12 @@ impl SessionBuilder { pub(crate) async fn just_connect>(&self, host: S) -> Result { let destination = host.as_ref(); + + let defaultdir = Path::new("./").to_path_buf(); + let socketdir = self.control_dir.as_ref().unwrap_or(&defaultdir); let dir = Builder::new() .prefix(".ssh-connection") - .tempdir_in("./") + .tempdir_in(socketdir) .map_err(Error::Master)?; let mut init = process::Command::new("ssh"); diff --git a/tests/openssh.rs b/tests/openssh.rs index a12299743..8179ecc6c 100644 --- a/tests/openssh.rs +++ b/tests/openssh.rs @@ -17,6 +17,24 @@ async fn it_connects() { session.close().await.unwrap(); } +#[tokio::test] +#[cfg_attr(not(ci), ignore)] +async fn control_dir() { + let dirname = std::path::Path::new("control-test"); + assert!(!dirname.exists()); + std::fs::create_dir(dirname).unwrap(); + let session = SessionBuilder::default() + .control_directory(&dirname) + .connect(&addr()) + .await + .unwrap(); + session.check().await.unwrap(); + let mut iter = std::fs::read_dir(&dirname).unwrap(); + assert!(iter.next().is_some()); + session.close().await.unwrap(); + std::fs::remove_dir(&dirname).unwrap(); +} + #[tokio::test] #[cfg_attr(not(ci), ignore)] async fn terminate_on_drop() {