From 6508f13644e88df5bb36e7912d63744f5a6b10b4 Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Tue, 4 May 2021 16:01:35 -0700 Subject: [PATCH 1/4] feat: unstable HttpClient Config This adds an `unstable-config` feature, with a new `Config` struct, which can be used to configure any `HttpClient` which implements support for it. Currently it supports two features - the most important and most generally supported: - `timeout` (`Duration`) - `no_delay` (`bool`) Implementations are provided for async-h1, isahc, and hyper (partial, no `no_delay` support due to the tls connector). No serious attempt has been made to add this to the wasm client at this point, since I don't understand well how to even build the wasm client or if it even works anymore with the state of rust wasm web build tools. --- Cargo.toml | 5 +++- src/config.rs | 47 +++++++++++++++++++++++++++++++ src/h1/mod.rs | 69 +++++++++++++++++++++++++++++++++++++++++---- src/h1/tcp.rs | 19 +++++++++++-- src/h1/tls.rs | 19 +++++++++++-- src/hyper.rs | 78 +++++++++++++++++++++++++++++++++++++++++++++------ src/isahc.rs | 68 +++++++++++++++++++++++++++++++++++++++++--- src/lib.rs | 44 ++++++++++++++++++++++++++++- 8 files changed, 322 insertions(+), 27 deletions(-) create mode 100644 src/config.rs diff --git a/Cargo.toml b/Cargo.toml index 0040a77..a625fd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,11 +27,13 @@ h1_client = ["async-h1", "async-std", "deadpool", "futures"] native_client = ["curl_client", "wasm_client"] curl_client = ["isahc", "async-std"] wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"] -hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"] +hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util", "tokio"] native-tls = ["async-native-tls"] rustls = ["async-tls"] +unstable-config = [] + [dependencies] async-trait = "0.1.37" dashmap = "4.0.2" @@ -53,6 +55,7 @@ async-tls = { version = "0.10.0", optional = true } hyper = { version = "0.13.6", features = ["tcp"], optional = true } hyper-tls = { version = "0.4.3", optional = true } futures-util = { version = "0.3.5", features = ["io"], optional = true } +tokio = { version = "0.2", features = ["time"], optional = true } # curl_client [target.'cfg(not(target_arch = "wasm32"))'.dependencies] diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..538337f --- /dev/null +++ b/src/config.rs @@ -0,0 +1,47 @@ +//! Configuration for `HttpClient`s. + +use std::time::Duration; + +/// Configuration for `HttpClient`s. +#[non_exhaustive] +#[derive(Clone, Debug)] +pub struct Config { + /// TCP `NO_DELAY`. + /// + /// Default: `false`. + pub no_delay: bool, + /// Connection timeout duration. + /// + /// Default: `Some(Duration::from_secs(60))`. + pub timeout: Option, +} + +impl Config { + /// Construct new empty config. + pub fn new() -> Self { + Self { + no_delay: false, + timeout: Some(Duration::from_secs(60)), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self::new() + } +} + +impl Config { + /// Set TCP `NO_DELAY`. + pub fn set_no_delay(mut self, no_delay: bool) -> Self { + self.no_delay = no_delay; + self + } + + /// Set connection timeout duration. + pub fn set_timeout(mut self, timeout: Option) -> Self { + self.timeout = timeout; + self + } +} diff --git a/src/h1/mod.rs b/src/h1/mod.rs index a94a204..c8b21af 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -1,5 +1,8 @@ //! http-client implementation for async-h1, with connecton pooling ("Keep-Alive"). +#[cfg(feature = "unstable-config")] +use std::convert::{Infallible, TryFrom}; + use std::fmt::Debug; use std::net::SocketAddr; @@ -17,6 +20,8 @@ cfg_if::cfg_if! { } } +use crate::Config; + use super::{async_trait, Error, HttpClient, Request, Response}; mod tcp; @@ -40,6 +45,7 @@ pub struct H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: HttpsPool, max_concurrent_connections: usize, + config: Config, } impl Debug for H1Client { @@ -75,6 +81,7 @@ impl Debug for H1Client { .collect::>(), ) .field("https_pools", &https_pools) + .field("config", &self.config) .field( "max_concurrent_connections", &self.max_concurrent_connections, @@ -97,6 +104,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, + config: Config::default(), } } @@ -107,6 +115,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: max, + config: Config::default(), } } } @@ -152,7 +161,7 @@ impl HttpClient for H1Client { let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) { pool_ref } else { - let manager = TcpConnection::new(addr); + let manager = TcpConnection::new(addr, self.config.clone()); let pool = Pool::::new( manager, self.max_concurrent_connections, @@ -168,19 +177,28 @@ impl HttpClient for H1Client { let stream = match pool.get().await { Ok(s) => s, Err(_) if has_another_addr => continue, - Err(e) => return Err(Error::from_str(400, e.to_string()))?, + Err(e) => return Err(Error::from_str(400, e.to_string())), }; req.set_peer_addr(stream.peer_addr().ok()); req.set_local_addr(stream.local_addr().ok()); - return client::connect(TcpConnWrapper::new(stream), req).await; + + let tcp_conn = client::connect(TcpConnWrapper::new(stream), req); + #[cfg(feature = "unstable-config")] + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tcp_conn).await? + } else { + tcp_conn.await + }; + #[cfg(not(feature = "unstable-config"))] + return tcp_conn.await; } #[cfg(any(feature = "native-tls", feature = "rustls"))] "https" => { let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) { pool_ref } else { - let manager = TlsConnection::new(host.clone(), addr); + let manager = TlsConnection::new(host.clone(), addr, self.config.clone()); let pool = Pool::, Error>::new( manager, self.max_concurrent_connections, @@ -196,13 +214,21 @@ impl HttpClient for H1Client { let stream = match pool.get().await { Ok(s) => s, Err(_) if has_another_addr => continue, - Err(e) => return Err(Error::from_str(400, e.to_string()))?, + Err(e) => return Err(Error::from_str(400, e.to_string())), }; req.set_peer_addr(stream.get_ref().peer_addr().ok()); req.set_local_addr(stream.get_ref().local_addr().ok()); - return client::connect(TlsConnWrapper::new(stream), req).await; + let tls_conn = client::connect(TlsConnWrapper::new(stream), req); + #[cfg(feature = "unstable-config")] + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tls_conn).await? + } else { + tls_conn.await + }; + #[cfg(not(feature = "unstable-config"))] + return tls_conn.await; } _ => unreachable!(), } @@ -213,6 +239,37 @@ impl HttpClient for H1Client { "missing valid address", )) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + self.config = config; + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for H1Client { + type Error = Infallible; + + fn try_from(config: Config) -> Result { + Ok(Self { + http_pools: DashMap::new(), + #[cfg(any(feature = "native-tls", feature = "rustls"))] + https_pools: DashMap::new(), + max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, + config, + }) + } } #[cfg(test)] diff --git a/src/h1/tcp.rs b/src/h1/tcp.rs index d99e13d..5b6f075 100644 --- a/src/h1/tcp.rs +++ b/src/h1/tcp.rs @@ -8,13 +8,17 @@ use deadpool::managed::{Manager, Object, RecycleResult}; use futures::io::{AsyncRead, AsyncWrite}; use futures::task::{Context, Poll}; +use crate::Config; + #[derive(Clone, Debug)] pub(crate) struct TcpConnection { addr: SocketAddr, + config: Config, } + impl TcpConnection { - pub(crate) fn new(addr: SocketAddr) -> Self { - Self { addr } + pub(crate) fn new(addr: SocketAddr, config: Config) -> Self { + Self { addr, config } } } @@ -58,12 +62,21 @@ impl AsyncWrite for TcpConnWrapper { #[async_trait] impl Manager for TcpConnection { async fn create(&self) -> Result { - TcpStream::connect(self.addr).await + let tcp_stream = TcpStream::connect(self.addr).await?; + + #[cfg(feature = "unstable-config")] + tcp_stream.set_nodelay(self.config.no_delay)?; + + Ok(tcp_stream) } async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult { let mut buf = [0; 4]; let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + + #[cfg(feature = "unstable-config")] + conn.set_nodelay(self.config.no_delay)?; + match Pin::new(conn).poll_read(&mut cx, &mut buf) { Poll::Ready(Err(error)) => Err(error), Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new( diff --git a/src/h1/tls.rs b/src/h1/tls.rs index f7c714c..482a55c 100644 --- a/src/h1/tls.rs +++ b/src/h1/tls.rs @@ -16,16 +16,18 @@ cfg_if::cfg_if! { } } -use crate::Error; +use crate::{Config, Error}; #[derive(Clone, Debug)] pub(crate) struct TlsConnection { host: String, addr: SocketAddr, + config: Config, } + impl TlsConnection { - pub(crate) fn new(host: String, addr: SocketAddr) -> Self { - Self { host, addr } + pub(crate) fn new(host: String, addr: SocketAddr, config: Config) -> Self { + Self { host, addr, config } } } @@ -70,6 +72,10 @@ impl AsyncWrite for TlsConnWrapper { impl Manager, Error> for TlsConnection { async fn create(&self) -> Result, Error> { let raw_stream = async_std::net::TcpStream::connect(self.addr).await?; + + #[cfg(feature = "unstable-config")] + raw_stream.set_nodelay(self.config.no_delay)?; + let tls_stream = add_tls(&self.host, raw_stream).await?; Ok(tls_stream) } @@ -77,6 +83,12 @@ impl Manager, Error> for TlsConnection { async fn recycle(&self, conn: &mut TlsStream) -> RecycleResult { let mut buf = [0; 4]; let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + + #[cfg(feature = "unstable-config")] + conn.get_ref() + .set_nodelay(self.config.no_delay) + .map_err(Error::from)?; + match Pin::new(conn).poll_read(&mut cx, &mut buf) { Poll::Ready(Err(error)) => Err(error), Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new( @@ -86,6 +98,7 @@ impl Manager, Error> for TlsConnection { _ => Ok(()), } .map_err(Error::from)?; + Ok(()) } } diff --git a/src/hyper.rs b/src/hyper.rs index 44f816b..2a27a2f 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -1,16 +1,22 @@ //! http-client implementation for reqwest -use super::{async_trait, Error, HttpClient, Request, Response}; +#[cfg(feature = "unstable-config")] +use std::convert::Infallible; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io; +use std::str::FromStr; + use futures_util::stream::TryStreamExt; use http_types::headers::{HeaderName, HeaderValue}; use http_types::StatusCode; use hyper::body::HttpBody; use hyper::client::connect::Connect; use hyper_tls::HttpsConnector; -use std::convert::TryFrom; -use std::fmt::Debug; -use std::io; -use std::str::FromStr; + +use crate::Config; + +use super::{async_trait, Error, HttpClient, Request, Response}; type HyperRequest = hyper::Request; @@ -27,14 +33,21 @@ impl HyperClientObject for h /// Hyper-based HTTP Client. #[derive(Debug)] -pub struct HyperClient(Box); +pub struct HyperClient { + client: Box, + config: Config, +} impl HyperClient { /// Create a new client instance. pub fn new() -> Self { let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - Self(Box::new(client)) + + Self { + client: Box::new(client), + config: Config::default(), + } } /// Create from externally initialized and configured client. @@ -42,7 +55,10 @@ impl HyperClient { where C: Clone + Connect + Debug + Send + Sync + 'static, { - Self(Box::new(client)) + Self { + client: Box::new(client), + config: Config::default(), + } } } @@ -57,11 +73,55 @@ impl HttpClient for HyperClient { async fn send(&self, req: Request) -> Result { let req = HyperHttpRequest::try_from(req).await?.into_inner(); - let response = self.0.dyn_request(req).await?; + let conn_fut = self.client.dyn_request(req); + #[cfg(feature = "unstable-config")] + let response = if let Some(timeout) = self.config.timeout { + match tokio::time::timeout(timeout, conn_fut).await { + Err(_elapsed) => Err(Error::from_str(400, "Client timed out")), + Ok(Ok(try_res)) => Ok(try_res), + Ok(Err(e)) => Err(e.into()), + }? + } else { + conn_fut.await? + }; + + #[cfg(not(feature = "unstable-config"))] + let response = conn_fut.await?; let res = HttpTypesResponse::try_from(response).await?.into_inner(); Ok(res) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + self.config = config; + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for HyperClient { + type Error = Infallible; + + fn try_from(config: Config) -> Result { + let connector = HttpsConnector::new(); + let builder = hyper::Client::builder(); + + Ok(Self { + client: Box::new(builder.build(connector)), + config, + }) + } } struct HyperHttpRequest(HyperRequest); diff --git a/src/isahc.rs b/src/isahc.rs index 63c6f56..e5724ea 100644 --- a/src/isahc.rs +++ b/src/isahc.rs @@ -1,13 +1,23 @@ //! http-client implementation for isahc -use super::{async_trait, Body, Error, HttpClient, Request, Response}; +#[cfg(feature = "unstable-config")] +use std::convert::TryFrom; use async_std::io::BufReader; +#[cfg(feature = "unstable-config")] +use isahc::config::Configurable; use isahc::{http, ResponseExt}; +use crate::Config; + +use super::{async_trait, Body, Error, HttpClient, Request, Response}; + /// Curl-based HTTP Client. #[derive(Debug)] -pub struct IsahcClient(isahc::HttpClient); +pub struct IsahcClient { + client: isahc::HttpClient, + config: Config, +} impl Default for IsahcClient { fn default() -> Self { @@ -23,7 +33,10 @@ impl IsahcClient { /// Create from externally initialized and configured client. pub fn from_client(client: isahc::HttpClient) -> Self { - Self(client) + Self { + client, + config: Config::default(), + } } } @@ -45,7 +58,7 @@ impl HttpClient for IsahcClient { }; let request = builder.body(body).unwrap(); - let res = self.0.send_async(request).await.map_err(Error::from)?; + let res = self.client.send_async(request).await.map_err(Error::from)?; let maybe_metrics = res.metrics().cloned(); let (parts, body) = res.into_parts(); let body = Body::from_reader(BufReader::new(body), None); @@ -61,6 +74,53 @@ impl HttpClient for IsahcClient { response.set_body(body); Ok(response) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + let mut builder = isahc::HttpClient::builder(); + + if config.no_delay { + builder = builder.tcp_nodelay(); + } + if let Some(timeout) = config.timeout { + builder = builder.timeout(timeout); + } + + self.client = builder.build()?; + self.config = config; + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for IsahcClient { + type Error = isahc::Error; + + fn try_from(config: Config) -> Result { + let mut builder = isahc::HttpClient::builder(); + + if config.no_delay { + builder = builder.tcp_nodelay(); + } + if let Some(timeout) = config.timeout { + builder = builder.timeout(timeout); + } + + Ok(Self { + client: builder.build()?, + config, + }) + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 2e9b103..3e60c25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,13 @@ forbid(unsafe_code) )] +#[cfg(feature = "unstable-config")] +mod config; +#[cfg(feature = "unstable-config")] +pub use config::Config; +#[cfg(not(feature = "unstable-config"))] +type Config = (); + #[cfg_attr(feature = "docs", doc(cfg(curl_client)))] #[cfg(all(feature = "curl_client", not(target_arch = "wasm32")))] pub mod isahc; @@ -60,6 +67,31 @@ pub use http_types; pub trait HttpClient: std::fmt::Debug + Unpin + Send + Sync + 'static { /// Perform a request. async fn send(&self, req: Request) -> Result; + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, _config: Config) -> http_types::Result<()> { + unimplemented!( + "{} has not implemented `HttpClient::set_config()`", + type_name_of(self) + ) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + unimplemented!( + "{} has not implemented `HttpClient::config()`", + type_name_of(self) + ) + } +} + +#[cfg(feature = "unstable-config")] +fn type_name_of(_val: &T) -> &'static str { + std::any::type_name::() } /// The raw body of an http request or response. @@ -70,7 +102,17 @@ pub type Error = http_types::Error; #[async_trait] impl HttpClient for Box { - async fn send(&self, req: Request) -> Result { + async fn send(&self, req: Request) -> http_types::Result { self.as_ref().send(req).await } + + #[cfg(feature = "unstable-config")] + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + self.as_mut().set_config(config) + } + + #[cfg(feature = "unstable-config")] + fn config(&self) -> &Config { + self.as_ref().config() + } } From 06249b8016499417df84a2c2a52e6f80cd447552 Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Wed, 5 May 2021 12:30:08 -0700 Subject: [PATCH 2/4] feat: add `http_keep_alive` to Config Also renamed `no_delay` to `tcp_no_delay`. the options for configuring this in Isahc and Hyper are unfortunately not super clear. --- src/config.rs | 19 +++++++++++++++---- src/h1/mod.rs | 31 +++++++++++++++++++++++++++++++ src/h1/tcp.rs | 4 ++-- src/h1/tls.rs | 8 ++++---- src/hyper.rs | 14 +++++++++++++- src/isahc.rs | 10 ++++++++-- 6 files changed, 73 insertions(+), 13 deletions(-) diff --git a/src/config.rs b/src/config.rs index 538337f..c167e3f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,10 +6,14 @@ use std::time::Duration; #[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { + /// HTTP/1.1 `keep-alive` (connection pooling). + /// + /// Default: `true`. + pub http_keep_alive: bool, /// TCP `NO_DELAY`. /// /// Default: `false`. - pub no_delay: bool, + pub tcp_no_delay: bool, /// Connection timeout duration. /// /// Default: `Some(Duration::from_secs(60))`. @@ -20,7 +24,8 @@ impl Config { /// Construct new empty config. pub fn new() -> Self { Self { - no_delay: false, + http_keep_alive: true, + tcp_no_delay: false, timeout: Some(Duration::from_secs(60)), } } @@ -33,9 +38,15 @@ impl Default for Config { } impl Config { + /// Set HTTP/1.1 `keep-alive` (connection pooling). + pub fn set_http_keep_alive(mut self, keep_alive: bool) -> Self { + self.http_keep_alive = keep_alive; + self + } + /// Set TCP `NO_DELAY`. - pub fn set_no_delay(mut self, no_delay: bool) -> Self { - self.no_delay = no_delay; + pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self { + self.tcp_no_delay = no_delay; self } diff --git a/src/h1/mod.rs b/src/h1/mod.rs index c8b21af..1ade984 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -156,6 +156,37 @@ impl HttpClient for H1Client { for (idx, addr) in addrs.into_iter().enumerate() { let has_another_addr = idx != max_addrs_idx; + #[cfg(feature = "unstable-config")] + if !self.config.http_keep_alive { + match scheme { + "http" => { + let stream = async_std::net::TcpStream::connect(addr).await?; + req.set_peer_addr(stream.peer_addr().ok()); + req.set_local_addr(stream.local_addr().ok()); + let tcp_conn = client::connect(stream, req); + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tcp_conn).await? + } else { + tcp_conn.await + }; + } + #[cfg(any(feature = "native-tls", feature = "rustls"))] + "https" => { + let raw_stream = async_std::net::TcpStream::connect(addr).await?; + req.set_peer_addr(raw_stream.peer_addr().ok()); + req.set_local_addr(raw_stream.local_addr().ok()); + let tls_stream = tls::add_tls(&host, raw_stream).await?; + let tsl_conn = client::connect(tls_stream, req); + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tsl_conn).await? + } else { + tsl_conn.await + }; + } + _ => unreachable!(), + } + } + match scheme { "http" => { let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) { diff --git a/src/h1/tcp.rs b/src/h1/tcp.rs index 5b6f075..1ccfe06 100644 --- a/src/h1/tcp.rs +++ b/src/h1/tcp.rs @@ -65,7 +65,7 @@ impl Manager for TcpConnection { let tcp_stream = TcpStream::connect(self.addr).await?; #[cfg(feature = "unstable-config")] - tcp_stream.set_nodelay(self.config.no_delay)?; + tcp_stream.set_nodelay(self.config.tcp_no_delay)?; Ok(tcp_stream) } @@ -75,7 +75,7 @@ impl Manager for TcpConnection { let mut cx = Context::from_waker(futures::task::noop_waker_ref()); #[cfg(feature = "unstable-config")] - conn.set_nodelay(self.config.no_delay)?; + conn.set_nodelay(self.config.tcp_no_delay)?; match Pin::new(conn).poll_read(&mut cx, &mut buf) { Poll::Ready(Err(error)) => Err(error), diff --git a/src/h1/tls.rs b/src/h1/tls.rs index 482a55c..a994269 100644 --- a/src/h1/tls.rs +++ b/src/h1/tls.rs @@ -74,7 +74,7 @@ impl Manager, Error> for TlsConnection { let raw_stream = async_std::net::TcpStream::connect(self.addr).await?; #[cfg(feature = "unstable-config")] - raw_stream.set_nodelay(self.config.no_delay)?; + raw_stream.set_nodelay(self.config.tcp_no_delay)?; let tls_stream = add_tls(&self.host, raw_stream).await?; Ok(tls_stream) @@ -86,7 +86,7 @@ impl Manager, Error> for TlsConnection { #[cfg(feature = "unstable-config")] conn.get_ref() - .set_nodelay(self.config.no_delay) + .set_nodelay(self.config.tcp_no_delay) .map_err(Error::from)?; match Pin::new(conn).poll_read(&mut cx, &mut buf) { @@ -105,12 +105,12 @@ impl Manager, Error> for TlsConnection { cfg_if::cfg_if! { if #[cfg(feature = "rustls")] { - async fn add_tls(host: &str, stream: TcpStream) -> Result, std::io::Error> { + pub(crate) async fn add_tls(host: &str, stream: TcpStream) -> Result, std::io::Error> { let connector = async_tls::TlsConnector::default(); connector.connect(host, stream).await } } else if #[cfg(feature = "native-tls")] { - async fn add_tls( + pub(crate) async fn add_tls( host: &str, stream: TcpStream, ) -> Result, async_native_tls::Error> { diff --git a/src/hyper.rs b/src/hyper.rs index 2a27a2f..02cb416 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -97,6 +97,14 @@ impl HttpClient for HyperClient { /// /// Config options may not impact existing connections. fn set_config(&mut self, config: Config) -> http_types::Result<()> { + let connector = HttpsConnector::new(); + let mut builder = hyper::Client::builder(); + + if !config.http_keep_alive { + builder.pool_max_idle_per_host(1); + } + + self.client = Box::new(builder.build(connector)); self.config = config; Ok(()) @@ -115,7 +123,11 @@ impl TryFrom for HyperClient { fn try_from(config: Config) -> Result { let connector = HttpsConnector::new(); - let builder = hyper::Client::builder(); + let mut builder = hyper::Client::builder(); + + if !config.http_keep_alive { + builder.pool_max_idle_per_host(1); + } Ok(Self { client: Box::new(builder.build(connector)), diff --git a/src/isahc.rs b/src/isahc.rs index e5724ea..c5ff159 100644 --- a/src/isahc.rs +++ b/src/isahc.rs @@ -82,7 +82,10 @@ impl HttpClient for IsahcClient { fn set_config(&mut self, config: Config) -> http_types::Result<()> { let mut builder = isahc::HttpClient::builder(); - if config.no_delay { + if !config.http_keep_alive { + builder = builder.connection_cache_size(0); + } + if config.tcp_no_delay { builder = builder.tcp_nodelay(); } if let Some(timeout) = config.timeout { @@ -109,7 +112,10 @@ impl TryFrom for IsahcClient { fn try_from(config: Config) -> Result { let mut builder = isahc::HttpClient::builder(); - if config.no_delay { + if !config.http_keep_alive { + builder = builder.connection_cache_size(0); + } + if config.tcp_no_delay { builder = builder.tcp_nodelay(); } if let Some(timeout) = config.timeout { From 2740e2641ca9cb415fb80adeeccedcab06c0e876 Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Wed, 5 May 2021 15:44:12 -0700 Subject: [PATCH 3/4] feat: add `tls_config` to Config Only supports the h1 client, but has two different options, one each for native-tls and rustls. --- Cargo.toml | 3 ++- src/config.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++- src/h1/mod.rs | 4 ++-- src/h1/tcp.rs | 4 ++-- src/h1/tls.rs | 26 +++++++++++++++++++----- 5 files changed, 81 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a625fd8..eb91365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "fut hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util", "tokio"] native-tls = ["async-native-tls"] -rustls = ["async-tls"] +rustls = ["async-tls", "rustls_crate"] unstable-config = [] @@ -50,6 +50,7 @@ futures = { version = "0.3.8", optional = true } # h1_client_rustls async-tls = { version = "0.10.0", optional = true } +rustls_crate = { version = "0.18", optional = true, package = "rustls" } # hyper_client hyper = { version = "0.13.6", features = ["tcp"], optional = true } diff --git a/src/config.rs b/src/config.rs index c167e3f..acf9582 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,10 +1,11 @@ //! Configuration for `HttpClient`s. +use std::fmt::Debug; use std::time::Duration; /// Configuration for `HttpClient`s. #[non_exhaustive] -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Config { /// HTTP/1.1 `keep-alive` (connection pooling). /// @@ -18,6 +19,37 @@ pub struct Config { /// /// Default: `Some(Duration::from_secs(60))`. pub timeout: Option, + /// TLS Configuration (Rustls) + #[cfg(all(feature = "h1_client", feature = "rustls"))] + pub tls_config: Option>, + /// TLS Configuration (Native TLS) + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + pub tls_config: Option>, +} + +impl Debug for Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut dbg_struct = f.debug_struct("Config"); + dbg_struct + .field("http_keep_alive", &self.http_keep_alive) + .field("tcp_no_delay", &self.tcp_no_delay) + .field("timeout", &self.timeout); + + #[cfg(all(feature = "h1_client", feature = "rustls"))] + { + if self.tls_config.is_some() { + dbg_struct.field("tls_config", &"Some(rustls::ClientConfig)"); + } else { + dbg_struct.field("tls_config", &"None"); + } + } + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + { + dbg_struct.field("tls_config", &self.tls_config); + } + + dbg_struct.finish() + } } impl Config { @@ -27,6 +59,8 @@ impl Config { http_keep_alive: true, tcp_no_delay: false, timeout: Some(Duration::from_secs(60)), + #[cfg(all(feature = "h1_client", any(feature = "rustls", feature = "native-tls")))] + tls_config: None, } } } @@ -55,4 +89,23 @@ impl Config { self.timeout = timeout; self } + + /// Set TLS Configuration (Rustls) + #[cfg(all(feature = "h1_client", feature = "rustls"))] + pub fn set_tls_config( + mut self, + tls_config: Option>, + ) -> Self { + self.tls_config = tls_config; + self + } + /// Set TLS Configuration (Native TLS) + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + pub fn set_tls_config( + mut self, + tls_config: Option>, + ) -> Self { + self.tls_config = tls_config; + self + } } diff --git a/src/h1/mod.rs b/src/h1/mod.rs index 1ade984..b3e43ba 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -81,11 +81,11 @@ impl Debug for H1Client { .collect::>(), ) .field("https_pools", &https_pools) - .field("config", &self.config) .field( "max_concurrent_connections", &self.max_concurrent_connections, ) + .field("config", &self.config) .finish() } } @@ -175,7 +175,7 @@ impl HttpClient for H1Client { let raw_stream = async_std::net::TcpStream::connect(addr).await?; req.set_peer_addr(raw_stream.peer_addr().ok()); req.set_local_addr(raw_stream.local_addr().ok()); - let tls_stream = tls::add_tls(&host, raw_stream).await?; + let tls_stream = tls::add_tls(&host, raw_stream, &self.config).await?; let tsl_conn = client::connect(tls_stream, req); return if let Some(timeout) = self.config.timeout { async_std::future::timeout(timeout, tsl_conn).await? diff --git a/src/h1/tcp.rs b/src/h1/tcp.rs index 1ccfe06..6d2c151 100644 --- a/src/h1/tcp.rs +++ b/src/h1/tcp.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::net::SocketAddr; use std::pin::Pin; @@ -10,7 +9,8 @@ use futures::task::{Context, Poll}; use crate::Config; -#[derive(Clone, Debug)] +#[derive(Clone)] +#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))] pub(crate) struct TcpConnection { addr: SocketAddr, config: Config, diff --git a/src/h1/tls.rs b/src/h1/tls.rs index a994269..84682c9 100644 --- a/src/h1/tls.rs +++ b/src/h1/tls.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::net::SocketAddr; use std::pin::Pin; @@ -18,7 +17,8 @@ cfg_if::cfg_if! { use crate::{Config, Error}; -#[derive(Clone, Debug)] +#[derive(Clone)] +#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))] pub(crate) struct TlsConnection { host: String, addr: SocketAddr, @@ -76,7 +76,7 @@ impl Manager, Error> for TlsConnection { #[cfg(feature = "unstable-config")] raw_stream.set_nodelay(self.config.tcp_no_delay)?; - let tls_stream = add_tls(&self.host, raw_stream).await?; + let tls_stream = add_tls(&self.host, raw_stream, &self.config).await?; Ok(tls_stream) } @@ -105,16 +105,32 @@ impl Manager, Error> for TlsConnection { cfg_if::cfg_if! { if #[cfg(feature = "rustls")] { - pub(crate) async fn add_tls(host: &str, stream: TcpStream) -> Result, std::io::Error> { + #[allow(unused_variables)] + pub(crate) async fn add_tls(host: &str, stream: TcpStream, config: &Config) -> Result, std::io::Error> { + #[cfg(all(feature = "h1_client", feature = "unstable-config"))] + let connector = if let Some(tls_config) = config.tls_config.as_ref().cloned() { + tls_config.into() + } else { + async_tls::TlsConnector::default() + }; + #[cfg(not(feature = "unstable-config"))] let connector = async_tls::TlsConnector::default(); + connector.connect(host, stream).await } } else if #[cfg(feature = "native-tls")] { + #[allow(unused_variables)] pub(crate) async fn add_tls( host: &str, stream: TcpStream, + config: &Config, ) -> Result, async_native_tls::Error> { - async_native_tls::connect(host, stream).await + #[cfg(feature = "unstable-config")] + let connector = config.tls_config.as_ref().cloned().unwrap_or_default(); + #[cfg(not(feature = "unstable-config"))] + let connector = async_native_tls::TlsConnector::new(); + + connector.connect(host, stream).await } } } From a268be7cfdf10f301952aef8e2922a94ff4a2a43 Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Wed, 5 May 2021 15:49:28 -0700 Subject: [PATCH 4/4] refactor: share the H1Client Config via an Arc --- src/h1/mod.rs | 13 +++++++------ src/h1/tcp.rs | 5 +++-- src/h1/tls.rs | 5 +++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/h1/mod.rs b/src/h1/mod.rs index b3e43ba..155d100 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -5,6 +5,7 @@ use std::convert::{Infallible, TryFrom}; use std::fmt::Debug; use std::net::SocketAddr; +use std::sync::Arc; use async_h1::client; use async_std::net::TcpStream; @@ -45,7 +46,7 @@ pub struct H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: HttpsPool, max_concurrent_connections: usize, - config: Config, + config: Arc, } impl Debug for H1Client { @@ -104,7 +105,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, - config: Config::default(), + config: Arc::new(Config::default()), } } @@ -115,7 +116,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: max, - config: Config::default(), + config: Arc::new(Config::default()), } } } @@ -276,7 +277,7 @@ impl HttpClient for H1Client { /// /// Config options may not impact existing connections. fn set_config(&mut self, config: Config) -> http_types::Result<()> { - self.config = config; + self.config = Arc::new(config); Ok(()) } @@ -284,7 +285,7 @@ impl HttpClient for H1Client { #[cfg(feature = "unstable-config")] /// Get the current configuration. fn config(&self) -> &Config { - &self.config + &*self.config } } @@ -298,7 +299,7 @@ impl TryFrom for H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, - config, + config: Arc::new(config), }) } } diff --git a/src/h1/tcp.rs b/src/h1/tcp.rs index 6d2c151..6b855fd 100644 --- a/src/h1/tcp.rs +++ b/src/h1/tcp.rs @@ -1,5 +1,6 @@ use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use async_std::net::TcpStream; use async_trait::async_trait; @@ -13,11 +14,11 @@ use crate::Config; #[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))] pub(crate) struct TcpConnection { addr: SocketAddr, - config: Config, + config: Arc, } impl TcpConnection { - pub(crate) fn new(addr: SocketAddr, config: Config) -> Self { + pub(crate) fn new(addr: SocketAddr, config: Arc) -> Self { Self { addr, config } } } diff --git a/src/h1/tls.rs b/src/h1/tls.rs index 84682c9..796936c 100644 --- a/src/h1/tls.rs +++ b/src/h1/tls.rs @@ -1,5 +1,6 @@ use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use async_std::net::TcpStream; use async_trait::async_trait; @@ -22,11 +23,11 @@ use crate::{Config, Error}; pub(crate) struct TlsConnection { host: String, addr: SocketAddr, - config: Config, + config: Arc, } impl TlsConnection { - pub(crate) fn new(host: String, addr: SocketAddr, config: Config) -> Self { + pub(crate) fn new(host: String, addr: SocketAddr, config: Arc) -> Self { Self { host, addr, config } } }