Skip to content

Better handling of the postgres connection background task #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions src/pg/error_helper.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::error::Error;
use std::sync::Arc;

use diesel::ConnectionError;

pub(super) struct ErrorHelper(pub(super) tokio_postgres::Error);
Expand All @@ -10,40 +13,46 @@ impl From<ErrorHelper> for ConnectionError {

impl From<ErrorHelper> for diesel::result::Error {
fn from(ErrorHelper(postgres_error): ErrorHelper) -> Self {
use diesel::result::DatabaseErrorKind::*;
use tokio_postgres::error::SqlState;
from_tokio_postgres_error(Arc::new(postgres_error))
}
}

match postgres_error.code() {
Some(code) => {
let kind = match *code {
SqlState::UNIQUE_VIOLATION => UniqueViolation,
SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation,
SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure,
SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction,
SqlState::NOT_NULL_VIOLATION => NotNullViolation,
SqlState::CHECK_VIOLATION => CheckViolation,
_ => Unknown,
};
pub(super) fn from_tokio_postgres_error(
postgres_error: Arc<tokio_postgres::Error>,
) -> diesel::result::Error {
use diesel::result::DatabaseErrorKind::*;
use tokio_postgres::error::SqlState;

diesel::result::Error::DatabaseError(
kind,
Box::new(PostgresDbErrorWrapper(
postgres_error
.into_source()
.and_then(|e| e.downcast::<tokio_postgres::error::DbError>().ok())
.expect("It's a db error, because we've got a SQLState code above"),
)) as _,
)
}
None => diesel::result::Error::DatabaseError(
UnableToSendCommand,
Box::new(postgres_error.to_string()),
),
match postgres_error.code() {
Some(code) => {
let kind = match *code {
SqlState::UNIQUE_VIOLATION => UniqueViolation,
SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation,
SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure,
SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction,
SqlState::NOT_NULL_VIOLATION => NotNullViolation,
SqlState::CHECK_VIOLATION => CheckViolation,
_ => Unknown,
};

diesel::result::Error::DatabaseError(
kind,
Box::new(PostgresDbErrorWrapper(
postgres_error
.source()
.and_then(|e| e.downcast_ref::<tokio_postgres::error::DbError>().cloned())
.expect("It's a db error, because we've got a SQLState code above"),
)) as _,
)
}
None => diesel::result::Error::DatabaseError(
UnableToSendCommand,
Box::new(postgres_error.to_string()),
),
}
}

struct PostgresDbErrorWrapper(Box<tokio_postgres::error::DbError>);
struct PostgresDbErrorWrapper(tokio_postgres::error::DbError);

impl diesel::result::DatabaseErrorInformation for PostgresDbErrorWrapper {
fn message(&self) -> &str {
Expand Down
89 changes: 75 additions & 14 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ use diesel::pg::{
};
use diesel::query_builder::bind_collector::RawBytesBindCollector;
use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
use diesel::result::DatabaseErrorKind;
use diesel::{ConnectionError, ConnectionResult, QueryResult};
use futures_util::future::BoxFuture;
use futures_util::future::Either;
use futures_util::stream::{BoxStream, TryStreamExt};
use futures_util::TryFutureExt;
use futures_util::{Future, FutureExt, StreamExt};
use std::borrow::Cow;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio_postgres::types::ToSql;
use tokio_postgres::types::Type;
Expand Down Expand Up @@ -102,12 +107,20 @@ pub struct AsyncPgConnection {
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
shutdown_channel: Option<oneshot::Sender<()>>,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncPgConnection {
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
Ok(self.conn.batch_execute(query).await.map_err(ErrorHelper)?)
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
let batch_execute = self
.conn
.batch_execute(query)
.map_err(ErrorHelper)
.map_err(Into::into);
drive_future(connection_future, batch_execute).await
}
}

Expand All @@ -124,29 +137,37 @@ impl AsyncConnection for AsyncPgConnection {
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
.await
.map_err(ErrorHelper)?;
let (tx, rx) = tokio::sync::broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {e}");
match futures_util::future::select(shutdown_rx, connection).await {
Either::Left(_) | Either::Right((Ok(_), _)) => {}
Either::Right((Err(e), _)) => {
let _ = tx.send(Arc::new(e));
}
}
});
Self::try_from(client).await

Self::setup(client, Some(rx), Some(shutdown_tx)).await
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
where
T: AsQuery + 'query,
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
{
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
let query = source.as_query();
self.with_prepared_statement(query, |conn, stmt, binds| async move {
let load_future = self.with_prepared_statement(query, |conn, stmt, binds| async move {
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

Ok(res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new)
.boxed())
})
.boxed()
});

drive_future(connection_future, load_future).boxed()
}

fn execute_returning_count<'conn, 'query, T>(
Expand All @@ -156,7 +177,8 @@ impl AsyncConnection for AsyncPgConnection {
where
T: QueryFragment<Self::Backend> + QueryId + 'query,
{
self.with_prepared_statement(source, |conn, stmt, binds| async move {
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
let execute = self.with_prepared_statement(source, |conn, stmt, binds| async move {
let binds = binds
.iter()
.map(|b| b as &(dyn ToSql + Sync))
Expand All @@ -166,8 +188,8 @@ impl AsyncConnection for AsyncPgConnection {
.await
.map_err(ErrorHelper)?;
Ok(res as usize)
})
.boxed()
});
drive_future(connection_future, execute).boxed()
}

fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
Expand All @@ -182,15 +204,21 @@ impl AsyncConnection for AsyncPgConnection {
}
}

impl Drop for AsyncPgConnection {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_channel.take() {
let _ = tx.send(());
}
}
}

#[inline(always)]
fn update_transaction_manager_status<T>(
query_result: QueryResult<T>,
transaction_manager: &mut AnsiTransactionManager,
) -> QueryResult<T> {
if let Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::SerializationFailure,
_,
)) = query_result
if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
query_result
{
transaction_manager
.status
Expand Down Expand Up @@ -270,11 +298,21 @@ impl AsyncPgConnection {

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
Self::setup(conn, None, None).await
}

async fn setup(
conn: tokio_postgres::Client,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
shutdown_channel: Option<oneshot::Sender<()>>,
) -> ConnectionResult<Self> {
let mut conn = Self {
conn: Arc::new(conn),
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
connection_future,
shutdown_channel,
};
conn.set_config_options()
.await
Expand Down Expand Up @@ -470,6 +508,29 @@ async fn lookup_type(
Ok((r.get(0), r.get(1)))
}

async fn drive_future<R>(
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
client_future: impl Future<Output = Result<R, diesel::result::Error>>,
) -> Result<R, diesel::result::Error> {
if let Some(mut connection_future) = connection_future {
let client_future = std::pin::pin!(client_future);
let connection_future = std::pin::pin!(connection_future.recv());
match futures_util::future::select(client_future, connection_future).await {
Either::Left((res, _)) => res,
// we got an error from the background task
// return it to the user
Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
// seems like the background thread died for whatever reason
Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
DatabaseErrorKind::UnableToSendCommand,
Box::new(e.to_string()),
)),
}
} else {
client_future.await
}
}

#[cfg(any(
feature = "deadpool",
feature = "bb8",
Expand Down