diff --git a/examples/postgres/pooled-with-rustls/src/main.rs b/examples/postgres/pooled-with-rustls/src/main.rs index a18451c..87a8eb4 100644 --- a/examples/postgres/pooled-with-rustls/src/main.rs +++ b/examples/postgres/pooled-with-rustls/src/main.rs @@ -49,12 +49,8 @@ fn establish_connection(config: &str) -> BoxFuture BoxFuture, stmt_cache: Arc>>, @@ -156,24 +170,17 @@ impl AsyncConnection for AsyncPgConnection { let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls) .await .map_err(ErrorHelper)?; - let (tx, rx) = tokio::sync::broadcast::channel(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); - tokio::spawn(async move { - match futures_util::future::select(shutdown_rx, connection).await { - Either::Left(_) | Either::Right((Ok(_), _)) => {} - Either::Right((Err(e), _)) => { - let _ = tx.send(Arc::new(e)); - } - } - }); + + let (error_rx, shutdown_tx) = drive_connection(connection); let r = Self::setup( client, - Some(rx), + Some(error_rx), Some(shutdown_tx), Arc::clone(&instrumentation), ) .await; + instrumentation .lock() .unwrap_or_else(|e| e.into_inner()) @@ -367,6 +374,28 @@ impl AsyncPgConnection { .await } + /// Constructs a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`] and + /// [`tokio_postgres::Connection`] + pub async fn try_from_client_and_connection( + client: tokio_postgres::Client, + conn: tokio_postgres::Connection, + ) -> ConnectionResult + where + S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, + { + let (error_rx, shutdown_tx) = drive_connection(conn); + + Self::setup( + client, + Some(error_rx), + Some(shutdown_tx), + Arc::new(std::sync::Mutex::new( + diesel::connection::get_default_instrumentation(), + )), + ) + .await + } + async fn setup( conn: tokio_postgres::Client, connection_future: Option>>, @@ -826,6 +855,30 @@ async fn drive_future( } } +fn drive_connection( + conn: tokio_postgres::Connection, +) -> ( + broadcast::Receiver>, + oneshot::Sender<()>, +) +where + S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, +{ + let (error_tx, error_rx) = tokio::sync::broadcast::channel(1); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + match futures_util::future::select(shutdown_rx, conn).await { + Either::Left(_) | Either::Right((Ok(_), _)) => {} + Either::Right((Err(e), _)) => { + let _ = error_tx.send(Arc::new(e)); + } + } + }); + + (error_rx, shutdown_tx) +} + #[cfg(any( feature = "deadpool", feature = "bb8",