diff --git a/Cargo.toml b/Cargo.toml index 0040a77..eb91365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,10 +27,12 @@ h1_client = ["async-h1", "async-std", "deadpool", "futures"] native_client = ["curl_client", "wasm_client"] curl_client = ["isahc", "async-std"] wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"] -hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"] +hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util", "tokio"] native-tls = ["async-native-tls"] -rustls = ["async-tls"] +rustls = ["async-tls", "rustls_crate"] + +unstable-config = [] [dependencies] async-trait = "0.1.37" @@ -48,11 +50,13 @@ futures = { version = "0.3.8", optional = true } # h1_client_rustls async-tls = { version = "0.10.0", optional = true } +rustls_crate = { version = "0.18", optional = true, package = "rustls" } # hyper_client hyper = { version = "0.13.6", features = ["tcp"], optional = true } hyper-tls = { version = "0.4.3", optional = true } futures-util = { version = "0.3.5", features = ["io"], optional = true } +tokio = { version = "0.2", features = ["time"], optional = true } # curl_client [target.'cfg(not(target_arch = "wasm32"))'.dependencies] diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..acf9582 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,111 @@ +//! Configuration for `HttpClient`s. + +use std::fmt::Debug; +use std::time::Duration; + +/// Configuration for `HttpClient`s. +#[non_exhaustive] +#[derive(Clone)] +pub struct Config { + /// HTTP/1.1 `keep-alive` (connection pooling). + /// + /// Default: `true`. + pub http_keep_alive: bool, + /// TCP `NO_DELAY`. + /// + /// Default: `false`. + pub tcp_no_delay: bool, + /// Connection timeout duration. + /// + /// Default: `Some(Duration::from_secs(60))`. + pub timeout: Option, + /// TLS Configuration (Rustls) + #[cfg(all(feature = "h1_client", feature = "rustls"))] + pub tls_config: Option>, + /// TLS Configuration (Native TLS) + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + pub tls_config: Option>, +} + +impl Debug for Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut dbg_struct = f.debug_struct("Config"); + dbg_struct + .field("http_keep_alive", &self.http_keep_alive) + .field("tcp_no_delay", &self.tcp_no_delay) + .field("timeout", &self.timeout); + + #[cfg(all(feature = "h1_client", feature = "rustls"))] + { + if self.tls_config.is_some() { + dbg_struct.field("tls_config", &"Some(rustls::ClientConfig)"); + } else { + dbg_struct.field("tls_config", &"None"); + } + } + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + { + dbg_struct.field("tls_config", &self.tls_config); + } + + dbg_struct.finish() + } +} + +impl Config { + /// Construct new empty config. + pub fn new() -> Self { + Self { + http_keep_alive: true, + tcp_no_delay: false, + timeout: Some(Duration::from_secs(60)), + #[cfg(all(feature = "h1_client", any(feature = "rustls", feature = "native-tls")))] + tls_config: None, + } + } +} + +impl Default for Config { + fn default() -> Self { + Self::new() + } +} + +impl Config { + /// Set HTTP/1.1 `keep-alive` (connection pooling). + pub fn set_http_keep_alive(mut self, keep_alive: bool) -> Self { + self.http_keep_alive = keep_alive; + self + } + + /// Set TCP `NO_DELAY`. + pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self { + self.tcp_no_delay = no_delay; + self + } + + /// Set connection timeout duration. + pub fn set_timeout(mut self, timeout: Option) -> Self { + self.timeout = timeout; + self + } + + /// Set TLS Configuration (Rustls) + #[cfg(all(feature = "h1_client", feature = "rustls"))] + pub fn set_tls_config( + mut self, + tls_config: Option>, + ) -> Self { + self.tls_config = tls_config; + self + } + /// Set TLS Configuration (Native TLS) + #[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))] + pub fn set_tls_config( + mut self, + tls_config: Option>, + ) -> Self { + self.tls_config = tls_config; + self + } +} diff --git a/src/h1/mod.rs b/src/h1/mod.rs index a94a204..155d100 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -1,7 +1,11 @@ //! http-client implementation for async-h1, with connecton pooling ("Keep-Alive"). +#[cfg(feature = "unstable-config")] +use std::convert::{Infallible, TryFrom}; + use std::fmt::Debug; use std::net::SocketAddr; +use std::sync::Arc; use async_h1::client; use async_std::net::TcpStream; @@ -17,6 +21,8 @@ cfg_if::cfg_if! { } } +use crate::Config; + use super::{async_trait, Error, HttpClient, Request, Response}; mod tcp; @@ -40,6 +46,7 @@ pub struct H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: HttpsPool, max_concurrent_connections: usize, + config: Arc, } impl Debug for H1Client { @@ -79,6 +86,7 @@ impl Debug for H1Client { "max_concurrent_connections", &self.max_concurrent_connections, ) + .field("config", &self.config) .finish() } } @@ -97,6 +105,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, + config: Arc::new(Config::default()), } } @@ -107,6 +116,7 @@ impl H1Client { #[cfg(any(feature = "native-tls", feature = "rustls"))] https_pools: DashMap::new(), max_concurrent_connections: max, + config: Arc::new(Config::default()), } } } @@ -147,12 +157,43 @@ impl HttpClient for H1Client { for (idx, addr) in addrs.into_iter().enumerate() { let has_another_addr = idx != max_addrs_idx; + #[cfg(feature = "unstable-config")] + if !self.config.http_keep_alive { + match scheme { + "http" => { + let stream = async_std::net::TcpStream::connect(addr).await?; + req.set_peer_addr(stream.peer_addr().ok()); + req.set_local_addr(stream.local_addr().ok()); + let tcp_conn = client::connect(stream, req); + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tcp_conn).await? + } else { + tcp_conn.await + }; + } + #[cfg(any(feature = "native-tls", feature = "rustls"))] + "https" => { + let raw_stream = async_std::net::TcpStream::connect(addr).await?; + req.set_peer_addr(raw_stream.peer_addr().ok()); + req.set_local_addr(raw_stream.local_addr().ok()); + let tls_stream = tls::add_tls(&host, raw_stream, &self.config).await?; + let tsl_conn = client::connect(tls_stream, req); + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tsl_conn).await? + } else { + tsl_conn.await + }; + } + _ => unreachable!(), + } + } + match scheme { "http" => { let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) { pool_ref } else { - let manager = TcpConnection::new(addr); + let manager = TcpConnection::new(addr, self.config.clone()); let pool = Pool::::new( manager, self.max_concurrent_connections, @@ -168,19 +209,28 @@ impl HttpClient for H1Client { let stream = match pool.get().await { Ok(s) => s, Err(_) if has_another_addr => continue, - Err(e) => return Err(Error::from_str(400, e.to_string()))?, + Err(e) => return Err(Error::from_str(400, e.to_string())), }; req.set_peer_addr(stream.peer_addr().ok()); req.set_local_addr(stream.local_addr().ok()); - return client::connect(TcpConnWrapper::new(stream), req).await; + + let tcp_conn = client::connect(TcpConnWrapper::new(stream), req); + #[cfg(feature = "unstable-config")] + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tcp_conn).await? + } else { + tcp_conn.await + }; + #[cfg(not(feature = "unstable-config"))] + return tcp_conn.await; } #[cfg(any(feature = "native-tls", feature = "rustls"))] "https" => { let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) { pool_ref } else { - let manager = TlsConnection::new(host.clone(), addr); + let manager = TlsConnection::new(host.clone(), addr, self.config.clone()); let pool = Pool::, Error>::new( manager, self.max_concurrent_connections, @@ -196,13 +246,21 @@ impl HttpClient for H1Client { let stream = match pool.get().await { Ok(s) => s, Err(_) if has_another_addr => continue, - Err(e) => return Err(Error::from_str(400, e.to_string()))?, + Err(e) => return Err(Error::from_str(400, e.to_string())), }; req.set_peer_addr(stream.get_ref().peer_addr().ok()); req.set_local_addr(stream.get_ref().local_addr().ok()); - return client::connect(TlsConnWrapper::new(stream), req).await; + let tls_conn = client::connect(TlsConnWrapper::new(stream), req); + #[cfg(feature = "unstable-config")] + return if let Some(timeout) = self.config.timeout { + async_std::future::timeout(timeout, tls_conn).await? + } else { + tls_conn.await + }; + #[cfg(not(feature = "unstable-config"))] + return tls_conn.await; } _ => unreachable!(), } @@ -213,6 +271,37 @@ impl HttpClient for H1Client { "missing valid address", )) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + self.config = Arc::new(config); + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &*self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for H1Client { + type Error = Infallible; + + fn try_from(config: Config) -> Result { + Ok(Self { + http_pools: DashMap::new(), + #[cfg(any(feature = "native-tls", feature = "rustls"))] + https_pools: DashMap::new(), + max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS, + config: Arc::new(config), + }) + } } #[cfg(test)] diff --git a/src/h1/tcp.rs b/src/h1/tcp.rs index d99e13d..6b855fd 100644 --- a/src/h1/tcp.rs +++ b/src/h1/tcp.rs @@ -1,6 +1,6 @@ -use std::fmt::Debug; use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use async_std::net::TcpStream; use async_trait::async_trait; @@ -8,13 +8,18 @@ use deadpool::managed::{Manager, Object, RecycleResult}; use futures::io::{AsyncRead, AsyncWrite}; use futures::task::{Context, Poll}; -#[derive(Clone, Debug)] +use crate::Config; + +#[derive(Clone)] +#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))] pub(crate) struct TcpConnection { addr: SocketAddr, + config: Arc, } + impl TcpConnection { - pub(crate) fn new(addr: SocketAddr) -> Self { - Self { addr } + pub(crate) fn new(addr: SocketAddr, config: Arc) -> Self { + Self { addr, config } } } @@ -58,12 +63,21 @@ impl AsyncWrite for TcpConnWrapper { #[async_trait] impl Manager for TcpConnection { async fn create(&self) -> Result { - TcpStream::connect(self.addr).await + let tcp_stream = TcpStream::connect(self.addr).await?; + + #[cfg(feature = "unstable-config")] + tcp_stream.set_nodelay(self.config.tcp_no_delay)?; + + Ok(tcp_stream) } async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult { let mut buf = [0; 4]; let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + + #[cfg(feature = "unstable-config")] + conn.set_nodelay(self.config.tcp_no_delay)?; + match Pin::new(conn).poll_read(&mut cx, &mut buf) { Poll::Ready(Err(error)) => Err(error), Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new( diff --git a/src/h1/tls.rs b/src/h1/tls.rs index f7c714c..796936c 100644 --- a/src/h1/tls.rs +++ b/src/h1/tls.rs @@ -1,6 +1,6 @@ -use std::fmt::Debug; use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use async_std::net::TcpStream; use async_trait::async_trait; @@ -16,16 +16,19 @@ cfg_if::cfg_if! { } } -use crate::Error; +use crate::{Config, Error}; -#[derive(Clone, Debug)] +#[derive(Clone)] +#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))] pub(crate) struct TlsConnection { host: String, addr: SocketAddr, + config: Arc, } + impl TlsConnection { - pub(crate) fn new(host: String, addr: SocketAddr) -> Self { - Self { host, addr } + pub(crate) fn new(host: String, addr: SocketAddr, config: Arc) -> Self { + Self { host, addr, config } } } @@ -70,13 +73,23 @@ impl AsyncWrite for TlsConnWrapper { impl Manager, Error> for TlsConnection { async fn create(&self) -> Result, Error> { let raw_stream = async_std::net::TcpStream::connect(self.addr).await?; - let tls_stream = add_tls(&self.host, raw_stream).await?; + + #[cfg(feature = "unstable-config")] + raw_stream.set_nodelay(self.config.tcp_no_delay)?; + + let tls_stream = add_tls(&self.host, raw_stream, &self.config).await?; Ok(tls_stream) } async fn recycle(&self, conn: &mut TlsStream) -> RecycleResult { let mut buf = [0; 4]; let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + + #[cfg(feature = "unstable-config")] + conn.get_ref() + .set_nodelay(self.config.tcp_no_delay) + .map_err(Error::from)?; + match Pin::new(conn).poll_read(&mut cx, &mut buf) { Poll::Ready(Err(error)) => Err(error), Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new( @@ -86,22 +99,39 @@ impl Manager, Error> for TlsConnection { _ => Ok(()), } .map_err(Error::from)?; + Ok(()) } } cfg_if::cfg_if! { if #[cfg(feature = "rustls")] { - async fn add_tls(host: &str, stream: TcpStream) -> Result, std::io::Error> { + #[allow(unused_variables)] + pub(crate) async fn add_tls(host: &str, stream: TcpStream, config: &Config) -> Result, std::io::Error> { + #[cfg(all(feature = "h1_client", feature = "unstable-config"))] + let connector = if let Some(tls_config) = config.tls_config.as_ref().cloned() { + tls_config.into() + } else { + async_tls::TlsConnector::default() + }; + #[cfg(not(feature = "unstable-config"))] let connector = async_tls::TlsConnector::default(); + connector.connect(host, stream).await } } else if #[cfg(feature = "native-tls")] { - async fn add_tls( + #[allow(unused_variables)] + pub(crate) async fn add_tls( host: &str, stream: TcpStream, + config: &Config, ) -> Result, async_native_tls::Error> { - async_native_tls::connect(host, stream).await + #[cfg(feature = "unstable-config")] + let connector = config.tls_config.as_ref().cloned().unwrap_or_default(); + #[cfg(not(feature = "unstable-config"))] + let connector = async_native_tls::TlsConnector::new(); + + connector.connect(host, stream).await } } } diff --git a/src/hyper.rs b/src/hyper.rs index 44f816b..02cb416 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -1,16 +1,22 @@ //! http-client implementation for reqwest -use super::{async_trait, Error, HttpClient, Request, Response}; +#[cfg(feature = "unstable-config")] +use std::convert::Infallible; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io; +use std::str::FromStr; + use futures_util::stream::TryStreamExt; use http_types::headers::{HeaderName, HeaderValue}; use http_types::StatusCode; use hyper::body::HttpBody; use hyper::client::connect::Connect; use hyper_tls::HttpsConnector; -use std::convert::TryFrom; -use std::fmt::Debug; -use std::io; -use std::str::FromStr; + +use crate::Config; + +use super::{async_trait, Error, HttpClient, Request, Response}; type HyperRequest = hyper::Request; @@ -27,14 +33,21 @@ impl HyperClientObject for h /// Hyper-based HTTP Client. #[derive(Debug)] -pub struct HyperClient(Box); +pub struct HyperClient { + client: Box, + config: Config, +} impl HyperClient { /// Create a new client instance. pub fn new() -> Self { let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - Self(Box::new(client)) + + Self { + client: Box::new(client), + config: Config::default(), + } } /// Create from externally initialized and configured client. @@ -42,7 +55,10 @@ impl HyperClient { where C: Clone + Connect + Debug + Send + Sync + 'static, { - Self(Box::new(client)) + Self { + client: Box::new(client), + config: Config::default(), + } } } @@ -57,11 +73,67 @@ impl HttpClient for HyperClient { async fn send(&self, req: Request) -> Result { let req = HyperHttpRequest::try_from(req).await?.into_inner(); - let response = self.0.dyn_request(req).await?; + let conn_fut = self.client.dyn_request(req); + #[cfg(feature = "unstable-config")] + let response = if let Some(timeout) = self.config.timeout { + match tokio::time::timeout(timeout, conn_fut).await { + Err(_elapsed) => Err(Error::from_str(400, "Client timed out")), + Ok(Ok(try_res)) => Ok(try_res), + Ok(Err(e)) => Err(e.into()), + }? + } else { + conn_fut.await? + }; + + #[cfg(not(feature = "unstable-config"))] + let response = conn_fut.await?; let res = HttpTypesResponse::try_from(response).await?.into_inner(); Ok(res) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + let connector = HttpsConnector::new(); + let mut builder = hyper::Client::builder(); + + if !config.http_keep_alive { + builder.pool_max_idle_per_host(1); + } + + self.client = Box::new(builder.build(connector)); + self.config = config; + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for HyperClient { + type Error = Infallible; + + fn try_from(config: Config) -> Result { + let connector = HttpsConnector::new(); + let mut builder = hyper::Client::builder(); + + if !config.http_keep_alive { + builder.pool_max_idle_per_host(1); + } + + Ok(Self { + client: Box::new(builder.build(connector)), + config, + }) + } } struct HyperHttpRequest(HyperRequest); diff --git a/src/isahc.rs b/src/isahc.rs index 63c6f56..c5ff159 100644 --- a/src/isahc.rs +++ b/src/isahc.rs @@ -1,13 +1,23 @@ //! http-client implementation for isahc -use super::{async_trait, Body, Error, HttpClient, Request, Response}; +#[cfg(feature = "unstable-config")] +use std::convert::TryFrom; use async_std::io::BufReader; +#[cfg(feature = "unstable-config")] +use isahc::config::Configurable; use isahc::{http, ResponseExt}; +use crate::Config; + +use super::{async_trait, Body, Error, HttpClient, Request, Response}; + /// Curl-based HTTP Client. #[derive(Debug)] -pub struct IsahcClient(isahc::HttpClient); +pub struct IsahcClient { + client: isahc::HttpClient, + config: Config, +} impl Default for IsahcClient { fn default() -> Self { @@ -23,7 +33,10 @@ impl IsahcClient { /// Create from externally initialized and configured client. pub fn from_client(client: isahc::HttpClient) -> Self { - Self(client) + Self { + client, + config: Config::default(), + } } } @@ -45,7 +58,7 @@ impl HttpClient for IsahcClient { }; let request = builder.body(body).unwrap(); - let res = self.0.send_async(request).await.map_err(Error::from)?; + let res = self.client.send_async(request).await.map_err(Error::from)?; let maybe_metrics = res.metrics().cloned(); let (parts, body) = res.into_parts(); let body = Body::from_reader(BufReader::new(body), None); @@ -61,6 +74,59 @@ impl HttpClient for IsahcClient { response.set_body(body); Ok(response) } + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + let mut builder = isahc::HttpClient::builder(); + + if !config.http_keep_alive { + builder = builder.connection_cache_size(0); + } + if config.tcp_no_delay { + builder = builder.tcp_nodelay(); + } + if let Some(timeout) = config.timeout { + builder = builder.timeout(timeout); + } + + self.client = builder.build()?; + self.config = config; + + Ok(()) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + &self.config + } +} + +#[cfg(feature = "unstable-config")] +impl TryFrom for IsahcClient { + type Error = isahc::Error; + + fn try_from(config: Config) -> Result { + let mut builder = isahc::HttpClient::builder(); + + if !config.http_keep_alive { + builder = builder.connection_cache_size(0); + } + if config.tcp_no_delay { + builder = builder.tcp_nodelay(); + } + if let Some(timeout) = config.timeout { + builder = builder.timeout(timeout); + } + + Ok(Self { + client: builder.build()?, + config, + }) + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 2e9b103..3e60c25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,13 @@ forbid(unsafe_code) )] +#[cfg(feature = "unstable-config")] +mod config; +#[cfg(feature = "unstable-config")] +pub use config::Config; +#[cfg(not(feature = "unstable-config"))] +type Config = (); + #[cfg_attr(feature = "docs", doc(cfg(curl_client)))] #[cfg(all(feature = "curl_client", not(target_arch = "wasm32")))] pub mod isahc; @@ -60,6 +67,31 @@ pub use http_types; pub trait HttpClient: std::fmt::Debug + Unpin + Send + Sync + 'static { /// Perform a request. async fn send(&self, req: Request) -> Result; + + #[cfg(feature = "unstable-config")] + /// Override the existing configuration with new configuration. + /// + /// Config options may not impact existing connections. + fn set_config(&mut self, _config: Config) -> http_types::Result<()> { + unimplemented!( + "{} has not implemented `HttpClient::set_config()`", + type_name_of(self) + ) + } + + #[cfg(feature = "unstable-config")] + /// Get the current configuration. + fn config(&self) -> &Config { + unimplemented!( + "{} has not implemented `HttpClient::config()`", + type_name_of(self) + ) + } +} + +#[cfg(feature = "unstable-config")] +fn type_name_of(_val: &T) -> &'static str { + std::any::type_name::() } /// The raw body of an http request or response. @@ -70,7 +102,17 @@ pub type Error = http_types::Error; #[async_trait] impl HttpClient for Box { - async fn send(&self, req: Request) -> Result { + async fn send(&self, req: Request) -> http_types::Result { self.as_ref().send(req).await } + + #[cfg(feature = "unstable-config")] + fn set_config(&mut self, config: Config) -> http_types::Result<()> { + self.as_mut().set_config(config) + } + + #[cfg(feature = "unstable-config")] + fn config(&self) -> &Config { + self.as_ref().config() + } }