diff --git a/src/pg/error_helper.rs b/src/pg/error_helper.rs index 0b25f0e..9b7eb3c 100644 --- a/src/pg/error_helper.rs +++ b/src/pg/error_helper.rs @@ -1,3 +1,6 @@ +use std::error::Error; +use std::sync::Arc; + use diesel::ConnectionError; pub(super) struct ErrorHelper(pub(super) tokio_postgres::Error); @@ -10,40 +13,46 @@ impl From for ConnectionError { impl From for diesel::result::Error { fn from(ErrorHelper(postgres_error): ErrorHelper) -> Self { - use diesel::result::DatabaseErrorKind::*; - use tokio_postgres::error::SqlState; + from_tokio_postgres_error(Arc::new(postgres_error)) + } +} - match postgres_error.code() { - Some(code) => { - let kind = match *code { - SqlState::UNIQUE_VIOLATION => UniqueViolation, - SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation, - SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure, - SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction, - SqlState::NOT_NULL_VIOLATION => NotNullViolation, - SqlState::CHECK_VIOLATION => CheckViolation, - _ => Unknown, - }; +pub(super) fn from_tokio_postgres_error( + postgres_error: Arc, +) -> diesel::result::Error { + use diesel::result::DatabaseErrorKind::*; + use tokio_postgres::error::SqlState; - diesel::result::Error::DatabaseError( - kind, - Box::new(PostgresDbErrorWrapper( - postgres_error - .into_source() - .and_then(|e| e.downcast::().ok()) - .expect("It's a db error, because we've got a SQLState code above"), - )) as _, - ) - } - None => diesel::result::Error::DatabaseError( - UnableToSendCommand, - Box::new(postgres_error.to_string()), - ), + match postgres_error.code() { + Some(code) => { + let kind = match *code { + SqlState::UNIQUE_VIOLATION => UniqueViolation, + SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation, + SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure, + SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction, + SqlState::NOT_NULL_VIOLATION => NotNullViolation, + SqlState::CHECK_VIOLATION => CheckViolation, + _ => Unknown, + }; + + diesel::result::Error::DatabaseError( + kind, + Box::new(PostgresDbErrorWrapper( + postgres_error + .source() + .and_then(|e| e.downcast_ref::().cloned()) + .expect("It's a db error, because we've got a SQLState code above"), + )) as _, + ) } + None => diesel::result::Error::DatabaseError( + UnableToSendCommand, + Box::new(postgres_error.to_string()), + ), } } -struct PostgresDbErrorWrapper(Box); +struct PostgresDbErrorWrapper(tokio_postgres::error::DbError); impl diesel::result::DatabaseErrorInformation for PostgresDbErrorWrapper { fn message(&self) -> &str { diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 654874d..2432e15 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -16,12 +16,17 @@ use diesel::pg::{ }; use diesel::query_builder::bind_collector::RawBytesBindCollector; use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId}; +use diesel::result::DatabaseErrorKind; use diesel::{ConnectionError, ConnectionResult, QueryResult}; use futures_util::future::BoxFuture; +use futures_util::future::Either; use futures_util::stream::{BoxStream, TryStreamExt}; +use futures_util::TryFutureExt; use futures_util::{Future, FutureExt, StreamExt}; use std::borrow::Cow; use std::sync::Arc; +use tokio::sync::broadcast; +use tokio::sync::oneshot; use tokio::sync::Mutex; use tokio_postgres::types::ToSql; use tokio_postgres::types::Type; @@ -102,12 +107,20 @@ pub struct AsyncPgConnection { stmt_cache: Arc>>, transaction_state: Arc>, metadata_cache: Arc>, + connection_future: Option>>, + shutdown_channel: Option>, } #[async_trait::async_trait] impl SimpleAsyncConnection for AsyncPgConnection { async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { - Ok(self.conn.batch_execute(query).await.map_err(ErrorHelper)?) + let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe()); + let batch_execute = self + .conn + .batch_execute(query) + .map_err(ErrorHelper) + .map_err(Into::into); + drive_future(connection_future, batch_execute).await } } @@ -124,12 +137,18 @@ impl AsyncConnection for AsyncPgConnection { let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls) .await .map_err(ErrorHelper)?; + let (tx, rx) = tokio::sync::broadcast::channel(1); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {e}"); + match futures_util::future::select(shutdown_rx, connection).await { + Either::Left(_) | Either::Right((Ok(_), _)) => {} + Either::Right((Err(e), _)) => { + let _ = tx.send(Arc::new(e)); + } } }); - Self::try_from(client).await + + Self::setup(client, Some(rx), Some(shutdown_tx)).await } fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> @@ -137,16 +156,18 @@ impl AsyncConnection for AsyncPgConnection { T: AsQuery + 'query, T::Query: QueryFragment + QueryId + 'query, { + let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe()); let query = source.as_query(); - self.with_prepared_statement(query, |conn, stmt, binds| async move { + let load_future = self.with_prepared_statement(query, |conn, stmt, binds| async move { let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?; Ok(res .map_err(|e| diesel::result::Error::from(ErrorHelper(e))) .map_ok(PgRow::new) .boxed()) - }) - .boxed() + }); + + drive_future(connection_future, load_future).boxed() } fn execute_returning_count<'conn, 'query, T>( @@ -156,7 +177,8 @@ impl AsyncConnection for AsyncPgConnection { where T: QueryFragment + QueryId + 'query, { - self.with_prepared_statement(source, |conn, stmt, binds| async move { + let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe()); + let execute = self.with_prepared_statement(source, |conn, stmt, binds| async move { let binds = binds .iter() .map(|b| b as &(dyn ToSql + Sync)) @@ -166,8 +188,8 @@ impl AsyncConnection for AsyncPgConnection { .await .map_err(ErrorHelper)?; Ok(res as usize) - }) - .boxed() + }); + drive_future(connection_future, execute).boxed() } fn transaction_state(&mut self) -> &mut AnsiTransactionManager { @@ -182,15 +204,21 @@ impl AsyncConnection for AsyncPgConnection { } } +impl Drop for AsyncPgConnection { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_channel.take() { + let _ = tx.send(()); + } + } +} + #[inline(always)] fn update_transaction_manager_status( query_result: QueryResult, transaction_manager: &mut AnsiTransactionManager, ) -> QueryResult { - if let Err(diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::SerializationFailure, - _, - )) = query_result + if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) = + query_result { transaction_manager .status @@ -270,11 +298,21 @@ impl AsyncPgConnection { /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`] pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult { + Self::setup(conn, None, None).await + } + + async fn setup( + conn: tokio_postgres::Client, + connection_future: Option>>, + shutdown_channel: Option>, + ) -> ConnectionResult { let mut conn = Self { conn: Arc::new(conn), stmt_cache: Arc::new(Mutex::new(StmtCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), + connection_future, + shutdown_channel, }; conn.set_config_options() .await @@ -470,6 +508,29 @@ async fn lookup_type( Ok((r.get(0), r.get(1))) } +async fn drive_future( + connection_future: Option>>, + client_future: impl Future>, +) -> Result { + if let Some(mut connection_future) = connection_future { + let client_future = std::pin::pin!(client_future); + let connection_future = std::pin::pin!(connection_future.recv()); + match futures_util::future::select(client_future, connection_future).await { + Either::Left((res, _)) => res, + // we got an error from the background task + // return it to the user + Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)), + // seems like the background thread died for whatever reason + Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError( + DatabaseErrorKind::UnableToSendCommand, + Box::new(e.to_string()), + )), + } + } else { + client_future.await + } +} + #[cfg(any( feature = "deadpool", feature = "bb8",