diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 012714d710..9cb0690bda 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::logger::QueryLogger; -use crate::mysql::connection::stream::Busy; +use crate::mysql::connection::stream::Waiting; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::Status; use crate::mysql::protocol::statement::{ @@ -93,7 +93,7 @@ impl MySqlConnection { let mut logger = QueryLogger::new(sql, self.log_settings.clone()); self.stream.wait_until_ready().await?; - self.stream.busy = Busy::Result; + self.stream.waiting.push_back(Waiting::Result); Ok(Box::pin(try_stream! { // make a slot for the shared column data @@ -146,12 +146,12 @@ impl MySqlConnection { continue; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } // otherwise, this first packet is the start of the result-set metadata, - self.stream.busy = Busy::Row; + *self.stream.waiting.front_mut().unwrap() = Waiting::Row; let num_columns = packet.get_uint_lenenc() as usize; // column count @@ -179,11 +179,11 @@ impl MySqlConnection { if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one - self.stream.busy = Busy::Result; + *self.stream.waiting.front_mut().unwrap() = Waiting::Result; break; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index 509426a63d..4ade06beeb 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -16,7 +16,7 @@ mod executor; mod stream; mod tls; -pub(crate) use stream::{Busy, MySqlStream}; +pub(crate) use stream::{MySqlStream, Waiting}; const MAX_PACKET_SIZE: u32 = 1024; diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index 8b2f453608..e43cf253c6 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes}; @@ -16,15 +17,13 @@ pub struct MySqlStream { pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, - pub(crate) busy: Busy, + pub(crate) waiting: VecDeque, pub(crate) charset: CharSet, pub(crate) collation: Collation, } #[derive(Debug, PartialEq, Eq)] -pub(crate) enum Busy { - NotBusy, - +pub(crate) enum Waiting { // waiting for a result set Result, @@ -65,7 +64,7 @@ impl MySqlStream { } Ok(Self { - busy: Busy::NotBusy, + waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, @@ -80,32 +79,32 @@ impl MySqlStream { self.stream.flush().await?; } - while self.busy != Busy::NotBusy { - while self.busy == Busy::Row { + while !self.waiting.is_empty() { + while self.waiting.front() == Some(&Waiting::Row) { let packet = self.recv_packet().await?; if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.capabilities)?; - self.busy = if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - Busy::Result + if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + *self.waiting.front_mut().unwrap() = Waiting::Result; } else { - Busy::NotBusy + self.waiting.pop_front(); }; } } - while self.busy == Busy::Result { + while self.waiting.front() == Some(&Waiting::Result) { let packet = self.recv_packet().await?; if packet[0] == 0x00 || packet[0] == 0xff { let ok = packet.ok()?; if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); } } else { - self.busy = Busy::Row; + *self.waiting.front_mut().unwrap() = Waiting::Row; self.skip_result_metadata(packet).await?; } } @@ -150,7 +149,7 @@ impl MySqlStream { // TODO: packet joining if payload[0] == 0xff { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); // instead of letting this packet be looked at everywhere, we check here // and emit a proper Error diff --git a/sqlx-core/src/mysql/transaction.rs b/sqlx-core/src/mysql/transaction.rs index b62fc143b5..97cb121d0e 100644 --- a/sqlx-core/src/mysql/transaction.rs +++ b/sqlx-core/src/mysql/transaction.rs @@ -2,7 +2,7 @@ use futures_core::future::BoxFuture; use crate::error::Error; use crate::executor::Executor; -use crate::mysql::connection::Busy; +use crate::mysql::connection::Waiting; use crate::mysql::protocol::text::Query; use crate::mysql::{MySql, MySqlConnection}; use crate::transaction::{ @@ -57,7 +57,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.transaction_depth; if depth > 0 { - conn.stream.busy = Busy::Result; + conn.stream.waiting.push_back(Waiting::Result); conn.stream.sequence_id = 0; conn.stream .write_packet(Query(&*rollback_ansi_transaction_sql(depth))); diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index baeaf9923a..d78009b4e2 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -387,3 +387,62 @@ async fn test_issue_622() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_can_work_with_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute("CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);") + .await?; + + // begin .. rollback + + let mut tx = conn.begin().await?; + sqlx::query("INSERT INTO users (id) VALUES (?)") + .bind(1_i32) + .execute(&mut tx) + .await?; + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&mut tx) + .await?; + assert_eq!(count, 1); + tx.rollback().await?; + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&mut conn) + .await?; + assert_eq!(count, 0); + + // begin .. commit + + let mut tx = conn.begin().await?; + sqlx::query("INSERT INTO users (id) VALUES (?)") + .bind(1_i32) + .execute(&mut tx) + .await?; + tx.commit().await?; + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&mut conn) + .await?; + assert_eq!(count, 1); + + // begin .. (drop) + + { + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO users (id) VALUES (?)") + .bind(2) + .execute(&mut tx) + .await?; + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&mut tx) + .await?; + assert_eq!(count, 2); + // tx is dropped + } + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&mut conn) + .await?; + assert_eq!(count, 1); + + Ok(()) +}