diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index a392609602..54c0d0e3de 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -62,9 +62,9 @@ macro_rules! try_stream { ($($block:tt)*) => { crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move { macro_rules! r#yield { - ($v:expr) => { + ($v:expr) => {{ let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await; - } + }} } $($block)* diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 6b5b55a4ae..8f376cbfb0 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -15,7 +15,7 @@ pub struct BufStream where S: AsyncRead + AsyncWrite + Unpin, { - stream: S, + pub(crate) stream: S, // writes with `write` to the underlying stream are buffered // this can be flushed with `flush` diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index ea49a532cf..5241af0210 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -11,7 +11,6 @@ use crate::error::Error; use crate::executor::Executor; use crate::ext::ustr::UStr; use crate::io::Decode; -use crate::postgres::connection::stream::PgStream; use crate::postgres::message::{ Close, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, }; @@ -19,6 +18,8 @@ use crate::postgres::statement::PgStatementMetadata; use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres}; use crate::transaction::Transaction; +pub use self::stream::PgStream; + pub(crate) mod describe; mod establish; mod executor; @@ -66,7 +67,7 @@ pub struct PgConnection { impl PgConnection { // will return when the connection is ready for another query - async fn wait_until_ready(&mut self) -> Result<(), Error> { + pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.stream.wbuf.is_empty() { self.stream.flush().await?; } diff --git a/sqlx-core/src/postgres/copy.rs b/sqlx-core/src/postgres/copy.rs new file mode 100644 index 0000000000..0d8eb46f2d --- /dev/null +++ b/sqlx-core/src/postgres/copy.rs @@ -0,0 +1,317 @@ +use crate::error::{Error, Result}; +use crate::ext::async_stream::TryAsyncStream; +use crate::pool::{Pool, PoolConnection}; +use crate::postgres::connection::PgConnection; +use crate::postgres::message::{ + CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, +}; +use crate::postgres::Postgres; +use bytes::{BufMut, Bytes}; +use futures_core::stream::BoxStream; +use smallvec::alloc::borrow::Cow; +use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use std::convert::TryFrom; +use std::ops::{Deref, DerefMut}; + +impl PgConnection { + /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data + /// to Postgres. This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw(&mut self, statement: &str) -> Result> { + PgCopyIn::begin(self, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + #[allow(clippy::needless_lifetimes)] + pub async fn copy_out_raw<'c>( + &'c mut self, + statement: &str, + ) -> Result>> { + pg_begin_copy_out(self, statement).await + } +} + +impl Pool { + /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. + /// This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// A single connection will be checked out for the duration. + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw( + &mut self, + statement: &str, + ) -> Result>> { + PgCopyIn::begin(self.acquire().await?, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and begin streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + pub async fn copy_out_raw( + &mut self, + statement: &str, + ) -> Result>> { + pg_begin_copy_out(self.acquire().await?, statement).await + } +} + +/// A connection in streaming `COPY FROM STDIN` mode. +/// +/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. +/// +/// ### Note +/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection +/// will return an error the next time it is used. +#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] +pub struct PgCopyIn> { + conn: Option, + response: CopyResponse, +} + +impl> PgCopyIn { + async fn begin(mut conn: C, statement: &str) -> Result { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let response: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyInResponse) + .await?; + + Ok(PgCopyIn { + conn: Some(conn), + response, + }) + } + + /// Send a chunk of `COPY` data. + /// + /// If you're copying data from an `AsyncRead`, maybe consider [Self::copy_from] instead. + pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .stream + .send(CopyData(data)) + .await?; + + Ok(self) + } + + /// Copy data directly from `source` to the database without requiring an intermediate buffer. + /// + /// `source` will be read to the end. + /// + /// ### Note + /// You must still call either [Self::finish] or [Self::abort] to complete the process. + pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { + // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing + struct BufGuard<'s>(&'s mut Vec); + + impl Drop for BufGuard<'_> { + fn drop(&mut self) { + self.0.clear() + } + } + + let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); + + // flush any existing messages in the buffer and clear it + conn.stream.flush().await?; + + { + let buf_stream = &mut *conn.stream; + let stream = &mut buf_stream.stream; + + // ensures the buffer isn't left in an inconsistent state + let mut guard = BufGuard(&mut buf_stream.wbuf); + + let buf: &mut Vec = &mut guard.0; + buf.push(b'd'); // CopyData format code + buf.resize(5, 0); // reserve space for the length + + loop { + let read = match () { + // Tokio lets us read into the buffer without zeroing first + #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))] + _ if buf.len() != buf.capacity() => { + // in case we have some data in the buffer, which can occur + // if the previous write did not fill the buffer + buf.truncate(5); + source.read_buf(buf).await? + } + _ => { + // should be a no-op unless len != capacity + buf.resize(buf.capacity(), 0); + source.read(&mut buf[5..]).await? + } + }; + + if read == 0 { + break; + } + + let read32 = u32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + + (&mut buf[1..]).put_u32(read32 + 4); + + stream.write_all(&buf[..read + 5]).await?; + stream.flush().await?; + } + } + + Ok(self) + } + + /// Signal that the `COPY` process should be aborted and any data received should be discarded. + /// + /// The given message can be used for indicating the reason for the abort in the database logs. + /// + /// The server is expected to respond with an error, so only _unexpected_ errors are returned. + pub async fn abort(mut self, msg: impl Into) -> Result<()> { + let mut conn = self + .conn + .take() + .expect("PgCopyIn::fail_with: conn taken illegally"); + + conn.stream.send(CopyFail::new(msg)).await?; + + match conn.stream.recv().await { + Ok(msg) => Err(err_protocol!( + "fail_with: expected ErrorResponse, got: {:?}", + msg.format + )), + Err(Error::Database(e)) => { + match e.code() { + Some(Cow::Borrowed("57014")) => { + // postgres abort received error code + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + Ok(()) + } + _ => Err(Error::Database(e)), + } + } + Err(e) => Err(e), + } + } + + /// Signal that the `COPY` process is complete. + /// + /// The number of rows affected is returned. + pub async fn finish(mut self) -> Result { + let mut conn = self + .conn + .take() + .expect("CopyWriter::finish: conn taken illegally"); + + conn.stream.send(CopyDone).await?; + let cc: CommandComplete = conn + .stream + .recv_expect(MessageFormat::CommandComplete) + .await?; + + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + + Ok(cc.rows_affected()) + } +} + +impl> Drop for PgCopyIn { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + conn.stream.write(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )); + } + } +} + +async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( + mut conn: C, + statement: &str, +) -> Result>> { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let _: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyOutResponse) + .await?; + + let stream: TryAsyncStream<'c, Bytes> = try_stream! { + loop { + let msg = conn.stream.recv().await?; + match msg.format { + MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + MessageFormat::CopyDone => { + let _ = msg.decode::()?; + conn.stream.recv_expect(MessageFormat::CommandComplete).await?; + conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + return Ok(()) + }, + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + } + } + }; + + Ok(Box::pin(stream)) +} diff --git a/sqlx-core/src/postgres/message/copy.rs b/sqlx-core/src/postgres/message/copy.rs new file mode 100644 index 0000000000..58553d431b --- /dev/null +++ b/sqlx-core/src/postgres/message/copy.rs @@ -0,0 +1,96 @@ +use crate::error::Result; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use bytes::{Buf, BufMut, Bytes}; +use std::ops::Deref; + +/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` +pub struct CopyResponse { + pub format: i8, + pub num_columns: i16, + pub format_codes: Vec, +} + +pub struct CopyData(pub B); + +pub struct CopyFail { + pub message: String, +} + +pub struct CopyDone; + +impl Decode<'_> for CopyResponse { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let format = buf.get_i8(); + let num_columns = buf.get_i16(); + + let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); + + Ok(CopyResponse { + format, + num_columns, + format_codes, + }) + } +} + +impl Decode<'_> for CopyData { + fn decode_with(buf: Bytes, _: ()) -> Result { + // well.. that was easy + Ok(CopyData(buf)) + } +} + +impl> Encode<'_> for CopyData { + fn encode_with(&self, buf: &mut Vec, _context: ()) { + buf.push(b'd'); + buf.put_u32(self.0.len() as u32 + 4); + buf.extend_from_slice(&self.0); + } +} + +impl Decode<'_> for CopyFail { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + Ok(CopyFail { + message: buf.get_str_nul()?, + }) + } +} + +impl Encode<'_> for CopyFail { + fn encode_with(&self, buf: &mut Vec, _: ()) { + let len = 4 + self.message.len() + 1; + + buf.push(b'f'); // to pay respects + buf.put_u32(len as u32); + buf.put_str_nul(&self.message); + } +} + +impl CopyFail { + pub fn new(msg: impl Into) -> CopyFail { + CopyFail { + message: msg.into(), + } + } +} + +impl Decode<'_> for CopyDone { + fn decode_with(buf: Bytes, _: ()) -> Result { + if buf.is_empty() { + Ok(CopyDone) + } else { + Err(err_protocol!( + "expected no data for CopyDone, got: {:?}", + buf + )) + } + } +} + +impl Encode<'_> for CopyDone { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.reserve(4); + buf.push(b'c'); + buf.put_u32(4); + } +} diff --git a/sqlx-core/src/postgres/message/mod.rs b/sqlx-core/src/postgres/message/mod.rs index 91aa578911..1261bff339 100644 --- a/sqlx-core/src/postgres/message/mod.rs +++ b/sqlx-core/src/postgres/message/mod.rs @@ -8,6 +8,7 @@ mod backend_key_data; mod bind; mod close; mod command_complete; +mod copy; mod data_row; mod describe; mod execute; @@ -32,6 +33,7 @@ pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; +pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; @@ -59,6 +61,10 @@ pub enum MessageFormat { BindComplete, CloseComplete, CommandComplete, + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, DataRow, EmptyQueryResponse, ErrorResponse, @@ -98,6 +104,10 @@ impl MessageFormat { b'2' => MessageFormat::BindComplete, b'3' => MessageFormat::CloseComplete, b'C' => MessageFormat::CommandComplete, + b'd' => MessageFormat::CopyData, + b'c' => MessageFormat::CopyDone, + b'G' => MessageFormat::CopyInResponse, + b'H' => MessageFormat::CopyOutResponse, b'D' => MessageFormat::DataRow, b'E' => MessageFormat::ErrorResponse, b'I' => MessageFormat::EmptyQueryResponse, diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 4555b49566..ae71a18de5 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -5,6 +5,7 @@ use crate::executor::Executor; mod arguments; mod column; mod connection; +mod copy; mod database; mod error; mod io; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index a88c0085e7..9264ae21aa 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1,4 +1,4 @@ -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use sqlx::postgres::{ PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity, }; @@ -1106,6 +1106,100 @@ async fn test_pg_server_num() -> anyhow::Result<()> { } #[sqlx_macros::test] +async fn it_can_copy_in() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute( + r#" + CREATE TEMPORARY TABLE users (id INTEGER NOT NULL); + "#, + ) + .await?; + + let mut copy = conn + .copy_in_raw( + r#" + COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER); + "#, + ) + .await?; + + copy.send("id\n1\n2\n".as_bytes()).await?; + let rows = copy.finish().await?; + assert_eq!(rows, 2); + + // conn is safe for reuse + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_abort_copy_in() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute( + r#" + CREATE TEMPORARY TABLE users (id INTEGER NOT NULL); + "#, + ) + .await?; + + let mut copy = conn + .copy_in_raw( + r#" + COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER); + "#, + ) + .await?; + + copy.abort("this is only a test").await?; + + // conn is safe for reuse + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_copy_out() -> anyhow::Result<()> { + let mut conn = new::().await?; + + { + let mut copy = conn + .copy_out_raw( + " + COPY (SELECT generate_series(1, 2) AS id) TO STDOUT WITH (FORMAT CSV, HEADER); + ", + ) + .await?; + + assert_eq!(copy.next().await.unwrap().unwrap(), "id\n"); + assert_eq!(copy.next().await.unwrap().unwrap(), "1\n"); + assert_eq!(copy.next().await.unwrap().unwrap(), "2\n"); + if copy.next().await.is_some() { + anyhow::bail!("Unexpected data from COPY"); + } + } + + // conn is safe for reuse + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} async fn test_issue_1254() -> anyhow::Result<()> { #[derive(sqlx::Type)] diff --git a/tests/sqlite/.gitignore b/tests/sqlite/.gitignore new file mode 100644 index 0000000000..02a6711c35 --- /dev/null +++ b/tests/sqlite/.gitignore @@ -0,0 +1,2 @@ +sqlite.db +