From 6f8c3571d683f2187d40ecb67ca9fd18b627ee6b Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 15 Nov 2021 20:03:02 -0800 Subject: [PATCH] refactor(sqlite): make background thread responsible for all FFI calls --- Cargo.lock | 114 ++++-- Cargo.toml | 4 +- sqlx-core/Cargo.toml | 5 +- sqlx-core/src/common/mod.rs | 25 ++ sqlx-core/src/error.rs | 2 +- sqlx-core/src/sqlite/arguments.rs | 27 +- sqlx-core/src/sqlite/connection/collation.rs | 127 ++++-- sqlx-core/src/sqlite/connection/describe.rs | 137 ++++--- sqlx-core/src/sqlite/connection/establish.rs | 166 ++++---- sqlx-core/src/sqlite/connection/execute.rs | 117 ++++++ sqlx-core/src/sqlite/connection/executor.rs | 241 ++---------- sqlx-core/src/sqlite/connection/explain.rs | 18 +- sqlx-core/src/sqlite/connection/handle.rs | 79 ++-- sqlx-core/src/sqlite/connection/mod.rs | 204 ++++++++-- sqlx-core/src/sqlite/connection/worker.rs | 388 +++++++++++++++++++ sqlx-core/src/sqlite/mod.rs | 30 +- sqlx-core/src/sqlite/options/connect.rs | 13 +- sqlx-core/src/sqlite/options/mod.rs | 97 +++++ sqlx-core/src/sqlite/row.rs | 97 +---- sqlx-core/src/sqlite/statement/handle.rs | 62 ++- sqlx-core/src/sqlite/statement/mod.rs | 4 +- sqlx-core/src/sqlite/statement/virtual.rs | 200 +++++----- sqlx-core/src/sqlite/statement/worker.rs | 161 -------- sqlx-core/src/sqlite/transaction.rs | 72 +--- sqlx-core/src/sqlite/value.rs | 58 --- tests/sqlite/sqlite.rs | 74 +++- 26 files changed, 1490 insertions(+), 1032 deletions(-) create mode 100644 sqlx-core/src/sqlite/connection/execute.rs create mode 100644 sqlx-core/src/sqlite/connection/worker.rs delete mode 100644 sqlx-core/src/sqlite/statement/worker.rs diff --git a/Cargo.lock b/Cargo.lock index 4c2ef3d87a..fcac67101e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -846,6 +846,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "flume" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24c3fd473b3a903a62609e413ed7538f99e10b665ecb502b5e481a95283f8ab4" +dependencies = [ + "futures-core", + "futures-sink", + "pin-project", + "spin 0.9.2", +] + [[package]] name = "foreign-types" version = "0.3.2" @@ -900,9 +912,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2" +checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" dependencies = [ "futures-core", "futures-sink", @@ -910,15 +922,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" [[package]] name = "futures-executor" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79" +checksum = "29d6d2ff5bb10fb95c85b8ce46538a2e5f5e7fdc755623a7d4529ab8a4ed9d2a" dependencies = [ "futures-core", "futures-task", @@ -938,9 +950,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1" +checksum = "b1f9d34af5a1aac6fb380f735fe510746c38067c5bf16c7fd250280503c971b2" [[package]] name = "futures-lite" @@ -959,12 +971,10 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121" +checksum = "6dbd947adfffb0efc70599b3ddcf7b5597bb5fa9e245eb99f62b3a5f7bb8bd3c" dependencies = [ - "autocfg 1.0.1", - "proc-macro-hack", "proc-macro2", "quote", "syn", @@ -972,23 +982,22 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282" +checksum = "e3055baccb68d74ff6480350f8d6eb8fcfa3aa11bdc1a1ae3afdd0514617d508" [[package]] name = "futures-task" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" +checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" [[package]] name = "futures-util" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967" +checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" dependencies = [ - "autocfg 1.0.1", "futures-channel", "futures-core", "futures-io", @@ -998,8 +1007,6 @@ dependencies = [ "memchr", "pin-project-lite", "pin-utils", - "proc-macro-hack", - "proc-macro-nested", "slab", ] @@ -1243,7 +1250,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" dependencies = [ - "spin", + "spin 0.5.2", ] [[package]] @@ -1717,6 +1724,26 @@ dependencies = [ "ucd-trie", ] +[[package]] +name = "pin-project" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1622113ce508488160cff04e6abc60960e676d330e1ca0f77c0b8df17c81438f" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95af56fee93df76d721d356ac1ca41fccf168bc448eb14049234df764ba3e76" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.6" @@ -1865,17 +1892,11 @@ version = "0.5.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" -[[package]] -name = "proc-macro-nested" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" - [[package]] name = "proc-macro2" -version = "1.0.28" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" +checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029" dependencies = [ "unicode-xid", ] @@ -1912,9 +1933,9 @@ checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8" [[package]] name = "rand" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e" +checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" dependencies = [ "libc", "rand_chacha", @@ -1950,6 +1971,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "rayon" version = "1.5.1" @@ -2051,7 +2081,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", @@ -2356,6 +2386,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "511254be0c5bcf062b019a6c89c01a664aa359ded62f78aa72c6fc137c0590e5" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.4.1" @@ -2387,7 +2426,10 @@ dependencies = [ "dotenv", "env_logger 0.8.3", "futures", + "hex", "paste", + "rand", + "rand_xoshiro", "serde", "serde_json", "sqlx-core", @@ -2456,8 +2498,10 @@ dependencies = [ "dirs", "either", "encoding_rs", + "flume", "futures-channel", "futures-core", + "futures-executor", "futures-intrusive", "futures-util", "generic-array", @@ -2737,9 +2781,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" [[package]] name = "syn" -version = "1.0.74" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1873d832550d4588c3dbc20f01361ab00bfe741048f71e3fecf145a7cc18b29c" +checksum = "ecb2e6da8ee5eb9a61068762a32fa9619cc591ceb055b3687f4cd4051ec2e06b" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index aab4f3b613..31dd4380dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -150,7 +150,9 @@ paste = "1.0.1" serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0.53" url = "2.1.1" - +rand = "0.8.4" +rand_xoshiro = "0.6.0" +hex = "0.4" # # Any # diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index ee01a46c64..e119aa0db5 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -43,7 +43,7 @@ mysql = [ "rand", "rsa", ] -sqlite = ["libsqlite3-sys"] +sqlite = ["libsqlite3-sys", "futures-executor", "flume"] mssql = ["uuid", "encoding_rs", "regex"] any = [] @@ -122,6 +122,9 @@ futures-channel = { version = "0.3.5", default-features = false, features = ["si futures-core = { version = "0.3.5", default-features = false } futures-intrusive = "0.4.0" futures-util = { version = "0.3.5", default-features = false, features = ["alloc", "sink"] } +# used by the SQLite worker thread to block on the async mutex that locks the database handle +futures-executor = { version = "0.3.17", optional = true } +flume = { version = "0.10.9", optional = true, default-features = false, features = ["async"] } generic-array = { version = "0.14.4", default-features = false, optional = true } hex = "0.4.2" hmac = { version = "0.11.0", default-features = false, optional = true } diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs index f9698f28c2..63ed52815b 100644 --- a/sqlx-core/src/common/mod.rs +++ b/sqlx-core/src/common/mod.rs @@ -1,3 +1,28 @@ mod statement_cache; pub(crate) use statement_cache::StatementCache; +use std::fmt::{Debug, Formatter}; +use std::ops::{Deref, DerefMut}; + +/// A wrapper for `Fn`s that provides a debug impl that just says "Function" +pub(crate) struct DebugFn(pub F); + +impl Deref for DebugFn { + type Target = F; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DebugFn { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Debug for DebugFn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Function").finish() + } +} diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 6a152520db..0659375846 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -39,7 +39,7 @@ pub enum Error { Database(#[source] Box), /// Error communicating with the database backend. - #[error("error communicating with the server: {0}")] + #[error("error communicating with database: {0}")] Io(#[from] io::Error), /// Error occurred while attempting to establish a TLS connection. diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index dc899b8165..17b3b90f54 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -31,6 +31,16 @@ impl<'q> SqliteArguments<'q> { self.values.push(SqliteArgumentValue::Null); } } + + pub(crate) fn into_static(self) -> SqliteArguments<'static> { + SqliteArguments { + values: self + .values + .into_iter() + .map(SqliteArgumentValue::into_static) + .collect(), + } + } } impl<'q> Arguments<'q> for SqliteArguments<'q> { @@ -49,7 +59,7 @@ impl<'q> Arguments<'q> for SqliteArguments<'q> { } impl SqliteArguments<'_> { - pub(super) fn bind(&self, handle: &StatementHandle, offset: usize) -> Result { + pub(super) fn bind(&self, handle: &mut StatementHandle, offset: usize) -> Result { let mut arg_i = offset; // for handle in &statement.handles { @@ -95,7 +105,20 @@ impl SqliteArguments<'_> { } impl SqliteArgumentValue<'_> { - fn bind(&self, handle: &StatementHandle, i: usize) -> Result<(), Error> { + fn into_static(self) -> SqliteArgumentValue<'static> { + use SqliteArgumentValue::*; + + match self { + Null => Null, + Text(text) => Text(text.into_owned().into()), + Blob(blob) => Blob(blob.into_owned().into()), + Int(v) => Int(v), + Int64(v) => Int64(v), + Double(v) => Double(v), + } + } + + fn bind(&self, handle: &mut StatementHandle, i: usize) -> Result<(), Error> { use SqliteArgumentValue::*; let status = match self { diff --git a/sqlx-core/src/sqlite/connection/collation.rs b/sqlx-core/src/sqlite/connection/collation.rs index 8c58352592..ea9a50c40f 100644 --- a/sqlx-core/src/sqlite/connection/collation.rs +++ b/sqlx-core/src/sqlite/connection/collation.rs @@ -1,8 +1,10 @@ use std::cmp::Ordering; use std::ffi::CString; +use std::fmt::{self, Debug, Formatter}; use std::os::raw::{c_int, c_void}; use std::slice; use std::str::from_utf8_unchecked; +use std::sync::Arc; use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; @@ -10,46 +12,84 @@ use crate::error::Error; use crate::sqlite::connection::handle::ConnectionHandle; use crate::sqlite::SqliteError; -unsafe extern "C" fn free_boxed_value(p: *mut c_void) { - drop(Box::from_raw(p as *mut T)); -} - -pub(crate) fn create_collation( - handle: &ConnectionHandle, - name: &str, - compare: F, -) -> Result<(), Error> -where - F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, -{ - unsafe extern "C" fn call_boxed_closure( +#[derive(Clone)] +pub struct Collation { + name: Arc, + collate: Arc Ordering + Send + Sync + 'static>, + // SAFETY: these must match the concrete type of `collate` + call: unsafe extern "C" fn( arg1: *mut c_void, arg2: c_int, arg3: *const c_void, arg4: c_int, arg5: *const c_void, - ) -> c_int + ) -> c_int, + free: unsafe extern "C" fn(*mut c_void), +} + +impl Collation { + pub fn new(name: N, collate: F) -> Self where - C: Fn(&str, &str) -> Ordering, + N: Into>, + F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, { - let boxed_f: *mut C = arg1 as *mut C; - debug_assert!(!boxed_f.is_null()); - let s1 = { - let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); - from_utf8_unchecked(c_slice) - }; - let s2 = { - let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); - from_utf8_unchecked(c_slice) + unsafe extern "C" fn drop_arc_value(p: *mut c_void) { + drop(Arc::from_raw(p as *mut T)); + } + + Collation { + name: name.into(), + collate: Arc::new(collate), + call: call_boxed_closure::, + free: drop_arc_value::, + } + } + + pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> { + let raw_f = Arc::into_raw(Arc::clone(&self.collate)); + let c_name = CString::new(&*self.name) + .map_err(|_| err_protocol!("invalid collation name: {:?}", self.name))?; + let flags = SQLITE_UTF8; + let r = unsafe { + sqlite3_create_collation_v2( + handle.as_ptr(), + c_name.as_ptr(), + flags, + raw_f as *mut c_void, + Some(self.call), + Some(self.free), + ) }; - let t = (*boxed_f)(s1, s2); - match t { - Ordering::Less => -1, - Ordering::Equal => 0, - Ordering::Greater => 1, + if r == SQLITE_OK { + Ok(()) + } else { + // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. + drop(unsafe { Arc::from_raw(raw_f) }); + Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) } } +} + +impl Debug for Collation { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Collation") + .field("name", &self.name) + .finish_non_exhaustive() + } +} + +pub(crate) fn create_collation( + handle: &mut ConnectionHandle, + name: &str, + compare: F, +) -> Result<(), Error> +where + F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, +{ + unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + drop(Box::from_raw(p as *mut T)); + } let boxed_f: *mut F = Box::into_raw(Box::new(compare)); let c_name = @@ -74,3 +114,32 @@ where Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) } } + +unsafe extern "C" fn call_boxed_closure( + data: *mut c_void, + left_len: c_int, + left_ptr: *const c_void, + right_len: c_int, + right_ptr: *const c_void, +) -> c_int +where + C: Fn(&str, &str) -> Ordering, +{ + let boxed_f: *mut C = data as *mut C; + debug_assert!(!boxed_f.is_null()); + let s1 = { + let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len as usize); + from_utf8_unchecked(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len as usize); + from_utf8_unchecked(c_slice) + }; + let t = (*boxed_f)(s1, s2); + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } +} diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index cb86e7e024..37f66975cf 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -1,101 +1,94 @@ use crate::describe::Describe; use crate::error::Error; use crate::sqlite::connection::explain::explain; +use crate::sqlite::connection::ConnectionState; use crate::sqlite::statement::VirtualStatement; use crate::sqlite::type_info::DataType; -use crate::sqlite::{Sqlite, SqliteColumn, SqliteConnection}; +use crate::sqlite::{Sqlite, SqliteColumn}; use either::Either; -use futures_core::future::BoxFuture; use std::convert::identity; -pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( - conn: &'c mut SqliteConnection, - query: &'q str, -) -> BoxFuture<'e, Result, Error>> { - Box::pin(async move { - // describing a statement from SQLite can be involved - // each SQLx statement is comprised of multiple SQL statements +pub(super) fn describe(conn: &mut ConnectionState, query: &str) -> Result, Error> { + // describing a statement from SQLite can be involved + // each SQLx statement is comprised of multiple SQL statements - let statement = VirtualStatement::new(query, false); + let mut statement = VirtualStatement::new(query, false)?; - let mut columns = Vec::new(); - let mut nullable = Vec::new(); - let mut num_params = 0; + let mut columns = Vec::new(); + let mut nullable = Vec::new(); + let mut num_params = 0; - let mut statement = statement?; + // we start by finding the first statement that *can* return results + while let Some(stmt) = statement.prepare_next(&mut conn.handle)? { + num_params += stmt.handle.bind_parameter_count(); - // we start by finding the first statement that *can* return results - while let Some((stmt, ..)) = statement.prepare(&mut conn.handle)? { - num_params += stmt.bind_parameter_count(); + let mut stepped = false; - let mut stepped = false; - - let num = stmt.column_count(); - if num == 0 { - // no columns in this statement; skip - continue; - } + let num = stmt.handle.column_count(); + if num == 0 { + // no columns in this statement; skip + continue; + } - // next we try to use [column_decltype] to inspect the type of each column - columns.reserve(num); + // next we try to use [column_decltype] to inspect the type of each column + columns.reserve(num); - // as a last resort, we explain the original query and attempt to - // infer what would the expression types be as a fallback - // to [column_decltype] + // as a last resort, we explain the original query and attempt to + // infer what would the expression types be as a fallback + // to [column_decltype] - // if explain.. fails, ignore the failure and we'll have no fallback - let (fallback, fallback_nullable) = match explain(conn, stmt.sql()).await { - Ok(v) => v, - Err(err) => { - log::debug!("describe: explain introspection failed: {}", err); + // if explain.. fails, ignore the failure and we'll have no fallback + let (fallback, fallback_nullable) = match explain(conn, stmt.handle.sql()) { + Ok(v) => v, + Err(err) => { + log::debug!("describe: explain introspection failed: {}", err); - (vec![], vec![]) + (vec![], vec![]) + } + }; + + for col in 0..num { + let name = stmt.handle.column_name(col).to_owned(); + + let type_info = if let Some(ty) = stmt.handle.column_decltype(col) { + ty + } else { + // if that fails, we back up and attempt to step the statement + // once *if* its read-only and then use [column_type] as a + // fallback to [column_decltype] + if !stepped && stmt.handle.read_only() { + stepped = true; + let _ = stmt.handle.step(); } - }; - for col in 0..num { - let name = stmt.column_name(col).to_owned(); - - let type_info = if let Some(ty) = stmt.column_decltype(col) { - ty - } else { - // if that fails, we back up and attempt to step the statement - // once *if* its read-only and then use [column_type] as a - // fallback to [column_decltype] - if !stepped && stmt.read_only() { - stepped = true; - let _ = conn.worker.step(stmt).await; - } - - let mut ty = stmt.column_type_info(col); + let mut ty = stmt.handle.column_type_info(col); - if ty.0 == DataType::Null { - if let Some(fallback) = fallback.get(col).cloned() { - ty = fallback; - } + if ty.0 == DataType::Null { + if let Some(fallback) = fallback.get(col).cloned() { + ty = fallback; } + } - ty - }; + ty + }; - // check explain - let col_nullable = stmt.column_nullable(col)?; - let exp_nullable = fallback_nullable.get(col).copied().and_then(identity); + // check explain + let col_nullable = stmt.handle.column_nullable(col)?; + let exp_nullable = fallback_nullable.get(col).copied().and_then(identity); - nullable.push(exp_nullable.or(col_nullable)); + nullable.push(exp_nullable.or(col_nullable)); - columns.push(SqliteColumn { - name: name.into(), - type_info, - ordinal: col, - }); - } + columns.push(SqliteColumn { + name: name.into(), + type_info, + ordinal: col, + }); } + } - Ok(Describe { - columns, - parameters: Some(Either::Right(num_params)), - nullable, - }) + Ok(Describe { + columns, + parameters: Some(Either::Right(num_params)), + nullable, }) } diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index ce8105a652..ccd913fa8c 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -1,87 +1,112 @@ +use crate::connection::LogSettings; use crate::error::Error; use crate::sqlite::connection::handle::ConnectionHandle; -use crate::sqlite::statement::StatementWorker; -use crate::{ - common::StatementCache, - sqlite::{SqliteConnectOptions, SqliteConnection, SqliteError}, -}; +use crate::sqlite::connection::{ConnectionState, Statements}; +use crate::sqlite::{SqliteConnectOptions, SqliteError}; use libsqlite3_sys::{ sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; -use sqlx_rt::blocking; +use std::ffi::CString; use std::io; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; use std::{ convert::TryFrom, ptr::{null, null_mut}, }; -pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result { - let mut filename = options - .filename - .to_str() - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "filename passed to SQLite must be valid UTF-8", - ) - })? - .to_owned(); - - // By default, we connect to an in-memory database. - // [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it - // cannot satisfy our wish for a thread-safe, lock-free connection object - - let mut flags = if options.serialized { - SQLITE_OPEN_FULLMUTEX - } else { - SQLITE_OPEN_NOMUTEX - }; - - flags |= if options.read_only { - SQLITE_OPEN_READONLY - } else if options.create_if_missing { - SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE - } else { - SQLITE_OPEN_READWRITE - }; - - if options.in_memory { - flags |= SQLITE_OPEN_MEMORY; - } +static THREAD_ID: AtomicU64 = AtomicU64::new(0); - flags |= if options.shared_cache { - SQLITE_OPEN_SHAREDCACHE - } else { - SQLITE_OPEN_PRIVATECACHE - }; +pub struct EstablishParams { + filename: CString, + open_flags: i32, + busy_timeout: Duration, + statement_cache_capacity: usize, + log_settings: LogSettings, + pub(crate) thread_name: String, + pub(crate) command_channel_size: usize, +} - if options.immutable { - filename = format!("file:{}?immutable=true", filename); - flags |= libsqlite3_sys::SQLITE_OPEN_URI; - } +impl EstablishParams { + pub fn from_options(options: &SqliteConnectOptions) -> Result { + let mut filename = options + .filename + .to_str() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "filename passed to SQLite must be valid UTF-8", + ) + })? + .to_owned(); + + // By default, we connect to an in-memory database. + // [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it + // cannot satisfy our wish for a thread-safe, lock-free connection object + + let mut flags = if options.serialized { + SQLITE_OPEN_FULLMUTEX + } else { + SQLITE_OPEN_NOMUTEX + }; - filename.push('\0'); + flags |= if options.read_only { + SQLITE_OPEN_READONLY + } else if options.create_if_missing { + SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE + } else { + SQLITE_OPEN_READWRITE + }; - let busy_timeout = options.busy_timeout; + if options.in_memory { + flags |= SQLITE_OPEN_MEMORY; + } - let handle = blocking!({ + flags |= if options.shared_cache { + SQLITE_OPEN_SHAREDCACHE + } else { + SQLITE_OPEN_PRIVATECACHE + }; + + if options.immutable { + filename = format!("file:{}?immutable=true", filename); + flags |= libsqlite3_sys::SQLITE_OPEN_URI; + } + + let filename = CString::new(filename).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "filename passed to SQLite must not contain nul bytes", + ) + })?; + + Ok(Self { + filename, + open_flags: flags, + busy_timeout: options.busy_timeout, + statement_cache_capacity: options.statement_cache_capacity, + log_settings: options.log_settings.clone(), + thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)), + command_channel_size: options.command_channel_size, + }) + } + + pub(crate) fn establish(&self) -> Result { let mut handle = null_mut(); // let mut status = unsafe { - sqlite3_open_v2( - filename.as_bytes().as_ptr() as *const _, - &mut handle, - flags, - null(), - ) + sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null()) }; if handle.is_null() { // Failed to allocate memory - panic!("SQLite is unable to allocate memory to hold the sqlite3 object"); + return Err(Error::Io(io::Error::new( + io::ErrorKind::OutOfMemory, + "SQLite is unable to allocate memory to hold the sqlite3 object", + ))); } // SAFE: tested for NULL just above @@ -101,12 +126,11 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result Result { + handle: &'a mut ConnectionHandle, + statement: &'a mut VirtualStatement, + logger: QueryLogger<'a>, + args: Option>, + + /// since a `VirtualStatement` can encompass multiple actual statements, + /// this keeps track of the number of arguments so far + args_used: usize, + + goto_next: bool, +} + +pub(crate) fn iter<'a>( + conn: &'a mut ConnectionState, + query: &'a str, + args: Option>, + persistent: bool, +) -> Result, Error> { + // fetch the cached statement or allocate a new one + let statement = conn.statements.get(query, persistent)?; + + let logger = QueryLogger::new(query, conn.log_settings.clone()); + + Ok(ExecuteIter { + handle: &mut conn.handle, + statement, + logger, + args, + args_used: 0, + goto_next: true, + }) +} + +fn bind( + statement: &mut StatementHandle, + arguments: &Option>, + offset: usize, +) -> Result { + let mut n = 0; + + if let Some(arguments) = arguments { + n = arguments.bind(statement, offset)?; + } + + Ok(n) +} + +impl Iterator for ExecuteIter<'_> { + type Item = Result, Error>; + + fn next(&mut self) -> Option { + let statement = if self.goto_next { + let mut statement = match self.statement.prepare_next(self.handle) { + Ok(Some(statement)) => statement, + Ok(None) => return None, + Err(e) => return Some(Err(e.into())), + }; + + self.goto_next = false; + + // sanity check: ensure the VM is reset and the bindings are cleared + if let Err(e) = statement.handle.reset() { + return Some(Err(e.into())); + } + + statement.handle.clear_bindings(); + + match bind(&mut statement.handle, &self.args, self.args_used) { + Ok(args_used) => self.args_used += args_used, + Err(e) => return Some(Err(e)), + } + + statement + } else { + self.statement.current()? + }; + + match statement.handle.step() { + Ok(true) => { + self.logger.increment_rows(); + + Some(Ok(Either::Right(SqliteRow::current( + &statement.handle, + &statement.columns, + &statement.column_names, + )))) + } + Ok(false) => { + let last_insert_rowid = self.handle.last_insert_rowid(); + + let done = SqliteQueryResult { + changes: statement.handle.changes(), + last_insert_rowid, + }; + + self.goto_next = true; + + Some(Ok(Either::Left(done))) + } + Err(e) => Some(Err(e.into())), + } + } +} + +impl Drop for ExecuteIter<'_> { + fn drop(&mut self) { + self.statement.reset().ok(); + } +} diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index 5ff7557d1e..69ed29b92d 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -1,88 +1,13 @@ -use crate::common::StatementCache; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::logger::QueryLogger; -use crate::sqlite::connection::describe::describe; -use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement}; use crate::sqlite::{ - Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement, - SqliteTypeInfo, + Sqlite, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement, SqliteTypeInfo, }; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use libsqlite3_sys::sqlite3_last_insert_rowid; -use std::borrow::Cow; -use std::sync::Arc; - -async fn prepare<'a>( - worker: &mut StatementWorker, - statements: &'a mut StatementCache, - statement: &'a mut Option, - query: &str, - persistent: bool, -) -> Result<&'a mut VirtualStatement, Error> { - if !persistent || !statements.is_enabled() { - *statement = Some(VirtualStatement::new(query, false)?); - return Ok(statement.as_mut().unwrap()); - } - - let exists = statements.contains_key(query); - - if !exists { - let statement = VirtualStatement::new(query, true)?; - statements.insert(query, statement); - } - - let statement = statements.get_mut(query).unwrap(); - - if exists { - // as this statement has been executed before, we reset before continuing - // this also causes any rows that are from the statement to be inflated - statement.reset(worker).await?; - } - - Ok(statement) -} - -fn bind( - statement: &StatementHandle, - arguments: &Option>, - offset: usize, -) -> Result { - let mut n = 0; - - if let Some(arguments) = arguments { - n += arguments.bind(statement, offset)?; - } - - Ok(n) -} - -/// A structure holding sqlite statement handle and resetting the -/// statement when it is dropped. -struct StatementResetter<'a> { - handle: Arc, - worker: &'a mut StatementWorker, -} - -impl<'a> StatementResetter<'a> { - fn new(worker: &'a mut StatementWorker, handle: &Arc) -> Self { - Self { - worker, - handle: Arc::clone(handle), - } - } -} - -impl Drop for StatementResetter<'_> { - fn drop(&mut self) { - // this method is designed to eagerly send the reset command - // so we don't need to await or spawn it - let _ = self.worker.reset(&self.handle); - } -} +use futures_util::{TryFutureExt, TryStreamExt}; impl<'c> Executor<'c> for &'c mut SqliteConnection { type Database = Sqlite; @@ -96,80 +21,15 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { E: Execute<'q, Self::Database>, { let sql = query.sql(); - let mut logger = QueryLogger::new(sql, self.log_settings.clone()); let arguments = query.take_arguments(); let persistent = query.persistent() && arguments.is_some(); - Box::pin(try_stream! { - let SqliteConnection { - handle: ref mut conn, - ref mut statements, - ref mut statement, - ref mut worker, - .. - } = self; - - // prepare statement object (or checkout from cache) - let stmt = prepare(worker, statements, statement, sql, persistent).await?; - - // keep track of how many arguments we have bound - let mut num_arguments = 0; - - while let Some((stmt, columns, column_names, last_row_values)) = stmt.prepare(conn)? { - // Prepare to reset raw SQLite statement when the handle - // is dropped. `StatementResetter` will reliably reset the - // statement even if the stream returned from `fetch_many` - // is dropped early. - let resetter = StatementResetter::new(worker, stmt); - - // bind values to the statement - num_arguments += bind(stmt, &arguments, num_arguments)?; - - loop { - // save the rows from the _current_ position on the statement - // and send them to the still-live row object - SqliteRow::inflate_if_needed(stmt, &*columns, last_row_values.take()); - - // invoke [sqlite3_step] on the dedicated worker thread - // this will move us forward one row or finish the statement - let s = resetter.worker.step(stmt).await?; - - match s { - Either::Left(changes) => { - let last_insert_rowid = unsafe { - sqlite3_last_insert_rowid(conn.as_ptr()) - }; - - let done = SqliteQueryResult { - changes, - last_insert_rowid, - }; - - r#yield!(Either::Left(done)); - - break; - } - - Either::Right(()) => { - let (row, weak_values_ref) = SqliteRow::current( - stmt.to_ref(conn.to_ref()), - columns, - column_names - ); - - let v = Either::Right(row); - *last_row_values = Some(weak_values_ref); - - logger.increment_rows(); - - r#yield!(v); - } - } - } - } - - Ok(()) - }) + Box::pin( + self.worker + .execute(sql, arguments, self.row_channel_size, persistent) + .map_ok(flume::Receiver::into_stream) + .try_flatten_stream(), + ) } fn fetch_optional<'e, 'q: 'e, E: 'q>( @@ -181,56 +41,24 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { E: Execute<'q, Self::Database>, { let sql = query.sql(); - let mut logger = QueryLogger::new(sql, self.log_settings.clone()); let arguments = query.take_arguments(); let persistent = query.persistent() && arguments.is_some(); Box::pin(async move { - let SqliteConnection { - handle: ref mut conn, - ref mut statements, - ref mut statement, - ref mut worker, - .. - } = self; - - // prepare statement object (or checkout from cache) - let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?; + let stream = self + .worker + .execute(sql, arguments, self.row_channel_size, persistent) + .map_ok(flume::Receiver::into_stream) + .try_flatten_stream(); - // keep track of how many arguments we have bound - let mut num_arguments = 0; + futures_util::pin_mut!(stream); - while let Some((stmt, columns, column_names, last_row_values)) = - virtual_stmt.prepare(conn)? - { - // bind values to the statement - num_arguments += bind(stmt, &arguments, num_arguments)?; - - // save the rows from the _current_ position on the statement - // and send them to the still-live row object - SqliteRow::inflate_if_needed(stmt, &*columns, last_row_values.take()); - - // invoke [sqlite3_step] on the dedicated worker thread - // this will move us forward one row or finish the statement - match worker.step(stmt).await? { - Either::Left(_) => (), - - Either::Right(()) => { - let (row, weak_values_ref) = SqliteRow::current( - stmt.to_ref(self.handle.to_ref()), - columns, - column_names, - ); - - *last_row_values = Some(weak_values_ref); - - logger.increment_rows(); - - virtual_stmt.reset(worker).await?; - return Ok(Some(row)); - } + while let Some(res) = stream.try_next().await? { + if let Either::Right(row) = res { + return Ok(Some(row)); } } + Ok(None) }) } @@ -244,36 +72,11 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'c: 'e, { Box::pin(async move { - let SqliteConnection { - handle: ref mut conn, - ref mut statements, - ref mut statement, - ref mut worker, - .. - } = self; - - // prepare statement object (or checkout from cache) - let statement = prepare(worker, statements, statement, sql, true).await?; - - let mut parameters = 0; - let mut columns = None; - let mut column_names = None; - - while let Some((statement, columns_, column_names_, _)) = statement.prepare(conn)? { - parameters += statement.bind_parameter_count(); - - // the first non-empty statement is chosen as the statement we pull columns from - if !columns_.is_empty() && columns.is_none() { - columns = Some(Arc::clone(columns_)); - column_names = Some(Arc::clone(column_names_)); - } - } + let statement = self.worker.prepare(sql).await?; Ok(SqliteStatement { - sql: Cow::Borrowed(sql), - columns: columns.unwrap_or_default(), - column_names: column_names.unwrap_or_default(), - parameters, + sql: sql.into(), + ..statement }) }) } @@ -283,6 +86,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { where 'c: 'e, { - describe(self, sql) + Box::pin(self.worker.describe(sql)) } } diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 14df95e6ac..6179125577 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -1,7 +1,8 @@ use crate::error::Error; -use crate::query_as::query_as; +use crate::from_row::FromRow; +use crate::sqlite::connection::{execute, ConnectionState}; use crate::sqlite::type_info::DataType; -use crate::sqlite::{SqliteConnection, SqliteTypeInfo}; +use crate::sqlite::SqliteTypeInfo; use crate::HashMap; use std::str::from_utf8; @@ -97,8 +98,8 @@ fn opcode_to_type(op: &str) -> DataType { } // Opcode Reference: https://sqlite.org/opcode.html -pub(super) async fn explain( - conn: &mut SqliteConnection, +pub(super) fn explain( + conn: &mut ConnectionState, query: &str, ) -> Result<(Vec, Vec>), Error> { // Registers @@ -111,10 +112,11 @@ pub(super) async fn explain( // Nullable columns let mut n = HashMap::::with_capacity(6); - let program = - query_as::<_, (i64, String, i64, i64, i64, Vec)>(&*format!("EXPLAIN {}", query)) - .fetch_all(&mut *conn) - .await?; + let program: Vec<(i64, String, i64, i64, i64, Vec)> = + execute::iter(conn, &format!("EXPLAIN {}", query), None, false)? + .filter_map(|res| res.map(|either| either.right()).transpose()) + .map(|row| FromRow::from_row(&row?)) + .collect::, Error>>()?; let mut program_i = 0; let program_size = program.len(); diff --git a/sqlx-core/src/sqlite/connection/handle.rs b/sqlx-core/src/sqlite/connection/handle.rs index c714fcc5f4..f293169afe 100644 --- a/sqlx-core/src/sqlite/connection/handle.rs +++ b/sqlx-core/src/sqlite/connection/handle.rs @@ -1,26 +1,20 @@ +use std::ffi::CString; +use std::ptr; use std::ptr::NonNull; -use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK}; +use crate::error::Error; +use libsqlite3_sys::{sqlite3, sqlite3_close, sqlite3_exec, sqlite3_last_insert_rowid, SQLITE_OK}; use crate::sqlite::SqliteError; -use std::sync::Arc; /// Managed handle to the raw SQLite3 database handle. /// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist. #[derive(Debug)] -pub(crate) struct ConnectionHandle(Arc); +pub(crate) struct ConnectionHandle(NonNull); -/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own -/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()` -/// or `sqlite3_reset()`. -/// -/// Note that this does *not* actually give access to the database handle! +/// A wrapper around `ConnectionHandle` which *does not* finalize the handle on-drop. #[derive(Clone, Debug)] -pub(crate) struct ConnectionHandleRef(Arc); - -// Wrapper for `*mut sqlite3` which finalizes the handle on-drop. -#[derive(Debug)] -struct HandleInner(NonNull); +pub(crate) struct ConnectionHandleRaw(NonNull); // A SQLite3 handle is safe to send between threads, provided not more than // one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is @@ -33,32 +27,67 @@ struct HandleInner(NonNull); unsafe impl Send for ConnectionHandle {} -// SAFETY: `Arc` normally only implements `Send` where `T: Sync` because it allows -// concurrent access. -// -// However, in this case we're only using `Arc` to prevent the database handle from being -// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus -// should *not* actually provide access to the database handle. -unsafe impl Send for ConnectionHandleRef {} +// SAFETY: this type does nothing but provide access to the DB handle pointer. +unsafe impl Send for ConnectionHandleRaw {} impl ConnectionHandle { #[inline] pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self { - Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr)))) + Self(NonNull::new_unchecked(ptr)) } #[inline] pub(crate) fn as_ptr(&self) -> *mut sqlite3 { - self.0 .0.as_ptr() + self.0.as_ptr() + } + + pub(crate) fn as_non_null_ptr(&self) -> NonNull { + self.0 } #[inline] - pub(crate) fn to_ref(&self) -> ConnectionHandleRef { - ConnectionHandleRef(Arc::clone(&self.0)) + pub(crate) fn to_raw(&self) -> ConnectionHandleRaw { + ConnectionHandleRaw(self.0) + } + + pub(crate) fn last_insert_rowid(&mut self) -> i64 { + // SAFETY: we have exclusive access to the database handle + unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } + } + + pub(crate) fn exec(&mut self, query: impl Into) -> Result<(), Error> { + let query = query.into(); + let query = CString::new(query).map_err(|_| err_protocol!("query contains nul bytes"))?; + + // SAFETY: we have exclusive access to the database handle + unsafe { + let status = sqlite3_exec( + self.as_ptr(), + query.as_ptr(), + // callback if we wanted result rows + None, + // callback data + ptr::null_mut(), + // out-pointer for the error message, we just use `SqliteError::new()` + ptr::null_mut(), + ); + + if status == SQLITE_OK { + Ok(()) + } else { + Err(SqliteError::new(self.as_ptr()).into()) + } + } + } +} + +impl ConnectionHandleRaw { + pub(crate) fn as_ptr(&self) -> *mut sqlite3 { + self.0.as_ptr() } } -impl Drop for HandleInner { +impl Drop for ConnectionHandle { fn drop(&mut self) { unsafe { // https://sqlite.org/c3ref/close.html diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index e001f08fa3..14234d92bd 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -1,59 +1,141 @@ +use std::cmp::Ordering; +use std::fmt::{self, Debug, Formatter}; +use std::ptr::NonNull; + +use futures_core::future::BoxFuture; +use futures_intrusive::sync::MutexGuard; +use futures_util::future; +use libsqlite3_sys::sqlite3; + +pub(crate) use handle::{ConnectionHandle, ConnectionHandleRaw}; + use crate::common::StatementCache; use crate::connection::{Connection, LogSettings}; use crate::error::Error; -use crate::sqlite::statement::{StatementWorker, VirtualStatement}; +use crate::sqlite::connection::establish::EstablishParams; +use crate::sqlite::connection::worker::ConnectionWorker; +use crate::sqlite::statement::VirtualStatement; use crate::sqlite::{Sqlite, SqliteConnectOptions}; use crate::transaction::Transaction; -use futures_core::future::BoxFuture; -use futures_util::future; -use libsqlite3_sys::sqlite3; -use std::cmp::Ordering; -use std::fmt::{self, Debug, Formatter}; -mod collation; +pub(crate) mod collation; mod describe; -pub(crate) mod establish; +mod establish; +mod execute; mod executor; mod explain; mod handle; -pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef}; +mod worker; -/// A connection to a [Sqlite] database. +/// A connection to an open [Sqlite] database. +/// +/// Because SQLite is an in-process database accessed by blocking API calls, SQLx uses a background +/// thread and communicates with it via channels to allow non-blocking access to the database. +/// +/// Dropping this struct will signal the worker thread to quit and close the database, though +/// if an error occurs there is no way to pass it back to the user this way. +/// +/// You can explicitly call [`.close()`][Self::close] to ensure the database is closed successfully +/// or get an error otherwise. pub struct SqliteConnection { + pub(crate) worker: ConnectionWorker, + pub(crate) row_channel_size: usize, +} + +pub struct LockedSqliteHandle<'a> { + pub(crate) guard: MutexGuard<'a, ConnectionState>, +} + +pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - pub(crate) worker: StatementWorker, // transaction status pub(crate) transaction_depth: usize, - // cache of semi-persistent statements - pub(crate) statements: StatementCache, - - // most recent non-persistent statement - pub(crate) statement: Option, + pub(crate) statements: Statements, log_settings: LogSettings, } +pub(crate) struct Statements { + // cache of semi-persistent statements + cached: StatementCache, + // most recent non-persistent statement + temp: Option, +} + impl SqliteConnection { - /// Returns the underlying sqlite3* connection handle + pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result { + let params = EstablishParams::from_options(options)?; + let worker = ConnectionWorker::establish(params).await?; + Ok(Self { + worker, + row_channel_size: options.row_channel_size, + }) + } + + /// Returns the underlying sqlite3* connection handle. + /// + /// ### Note + /// There is no synchronization using this method, beware that the background thread could + /// be making SQLite API calls concurrent to use of this method. + /// + /// You probably want to use [`.lock_handle()`][Self::lock_handle] to ensure that the worker thread is not using + /// the database concurrently. + #[deprecated = "Unsynchronized access is unsafe. See documentation for details."] pub fn as_raw_handle(&mut self) -> *mut sqlite3 { - self.handle.as_ptr() + self.worker.handle_raw.as_ptr() } + /// Apply a collation to the open database. + /// + /// See [`SqliteConnectOptions::collation()`] for details. + /// + /// ### Deprecated + /// Due to the rearchitecting of the SQLite driver, this method cannot actually function + /// synchronously and return the result directly from `sqlite3_create_collation_v2()`, so + /// it instead sends a message to the worker create the collation asynchronously. + /// If an error occurs it will simply be logged. + /// + /// Instead, you should specify the collation during the initial configuration with + /// [`SqliteConnectOptions::collation()`]. Then, if the collation fails to apply it will + /// return an error during the connection creation. When used with a [Pool][crate::pool::Pool], + /// this also ensures that the collation is applied to all connections automatically. + /// + /// Or if necessary, you can call [`.lock_handle()`][Self::lock_handle] + /// and create the collation directly with [`LockedSqliteHandle::create_collation()`]. + /// + /// [`Error::WorkerCrashed`] may still be returned if we could not communicate with the worker. + /// + /// Note that this may also block if the worker command channel is currently applying + /// backpressure. + #[deprecated = "Completes asynchronously. See documentation for details."] pub fn create_collation( &mut self, name: &str, compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static, ) -> Result<(), Error> { - collation::create_collation(&self.handle, name, compare) + self.worker.create_collation(name, compare) + } + + /// Lock the SQLite database handle out from the worker thread so direct SQLite API calls can + /// be made safely. + /// + /// Returns an error if the worker thread crashed. + pub async fn lock_handle(&mut self) -> Result, Error> { + let guard = self.worker.unlock_db().await?; + + Ok(LockedSqliteHandle { guard }) } } impl Debug for SqliteConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("SqliteConnection").finish() + f.debug_struct("SqliteConnection") + .field("row_channel_size", &self.row_channel_size) + .field("cached_statements_size", &self.cached_statements_size()) + .finish() } } @@ -65,7 +147,7 @@ impl Connection for SqliteConnection { fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { Box::pin(async move { let shutdown = self.worker.shutdown(); - // Drop the statement worker and any outstanding statements, which should + // Drop the statement worker, which should // cover all references to the connection handle outside of the worker thread drop(self); // Ensure the worker thread has terminated @@ -73,9 +155,9 @@ impl Connection for SqliteConnection { }) } + /// Ensure the background worker thread is alive and accepting commands. fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { - // For SQLite connections, PING does effectively nothing - Box::pin(future::ok(())) + Box::pin(self.worker.ping()) } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> @@ -86,19 +168,25 @@ impl Connection for SqliteConnection { } fn cached_statements_size(&self) -> usize { - self.statements.len() + self.worker + .shared + .cached_statements_size + .load(std::sync::atomic::Ordering::Acquire) } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { - self.statements.clear(); + self.worker.clear_cache().await?; Ok(()) }) } #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { - // For SQLite, FLUSH does effectively nothing + // For SQLite, FLUSH does effectively nothing... + // Well, we could use this to ensure that the command channel has been cleared, + // but it would only develop a backlog if a lot of queries are executed and then cancelled + // partway through, and then this would only make that situation worse. Box::pin(future::ok(())) } @@ -108,10 +196,70 @@ impl Connection for SqliteConnection { } } -impl Drop for SqliteConnection { +impl LockedSqliteHandle<'_> { + /// Returns the underlying sqlite3* connection handle. + /// + /// As long as this `LockedSqliteHandle` exists, it is guaranteed that the background thread + /// is not making FFI calls on this database handle or any of its statements. + pub fn as_raw_handle(&mut self) -> NonNull { + self.guard.handle.as_non_null_ptr() + } + + /// Apply a collation to the open database. + /// + /// See [`SqliteConnectOptions::collation()`] for details. + pub fn create_collation( + &mut self, + name: &str, + compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static, + ) -> Result<(), Error> { + collation::create_collation(&mut self.guard.handle, name, compare) + } +} + +impl Drop for ConnectionState { fn drop(&mut self) { // explicitly drop statements before the connection handle is dropped self.statements.clear(); - self.statement.take(); + } +} + +impl Statements { + fn new(capacity: usize) -> Self { + Statements { + cached: StatementCache::new(capacity), + temp: None, + } + } + + fn get(&mut self, query: &str, persistent: bool) -> Result<&mut VirtualStatement, Error> { + if !persistent || !self.cached.is_enabled() { + return Ok(self.temp.insert(VirtualStatement::new(query, false)?)); + } + + let exists = self.cached.contains_key(query); + + if !exists { + let statement = VirtualStatement::new(query, true)?; + self.cached.insert(query, statement); + } + + let statement = self.cached.get_mut(query).unwrap(); + + if exists { + // as this statement has been executed before, we reset before continuing + statement.reset()?; + } + + Ok(statement) + } + + fn len(&self) -> usize { + self.cached.len() + } + + fn clear(&mut self) { + self.cached.clear(); + self.temp = None; } } diff --git a/sqlx-core/src/sqlite/connection/worker.rs b/sqlx-core/src/sqlite/connection/worker.rs new file mode 100644 index 0000000000..527e363aa5 --- /dev/null +++ b/sqlx-core/src/sqlite/connection/worker.rs @@ -0,0 +1,388 @@ +use std::borrow::Cow; +use std::future::Future; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; + +use either::Either; +use futures_channel::oneshot; +use futures_intrusive::sync::{Mutex, MutexGuard}; + +use crate::describe::Describe; +use crate::error::Error; +use crate::sqlite::connection::collation::create_collation; +use crate::sqlite::connection::describe::describe; +use crate::sqlite::connection::establish::EstablishParams; +use crate::sqlite::connection::ConnectionState; +use crate::sqlite::connection::{execute, ConnectionHandleRaw}; +use crate::sqlite::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement}; +use crate::transaction::{ + begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql, +}; + +// Each SQLite connection has a dedicated thread. + +// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce +// OS resource usage. Low priority because a high concurrent load for SQLite3 is very +// unlikely. + +pub(crate) struct ConnectionWorker { + command_tx: flume::Sender, + /// The `sqlite3` pointer. NOTE: access is unsynchronized! + pub(crate) handle_raw: ConnectionHandleRaw, + /// Mutex for locking access to the database. + pub(crate) shared: Arc, +} + +pub(crate) struct WorkerSharedState { + pub(crate) cached_statements_size: AtomicUsize, + pub(crate) conn: Mutex, +} + +enum Command { + Prepare { + query: Box, + tx: oneshot::Sender, Error>>, + }, + Describe { + query: Box, + tx: oneshot::Sender, Error>>, + }, + Execute { + query: Box, + arguments: Option>, + persistent: bool, + tx: flume::Sender, Error>>, + }, + Begin { + tx: oneshot::Sender>, + }, + Commit { + tx: oneshot::Sender>, + }, + Rollback { + tx: Option>>, + }, + CreateCollation { + create_collation: + Box Result<(), Error> + Send + Sync + 'static>, + }, + UnlockDb, + ClearCache { + tx: oneshot::Sender<()>, + }, + Ping { + tx: oneshot::Sender<()>, + }, + Shutdown { + tx: oneshot::Sender<()>, + }, +} + +impl ConnectionWorker { + pub(crate) async fn establish(params: EstablishParams) -> Result { + let (establish_tx, establish_rx) = oneshot::channel(); + + thread::Builder::new() + .name(params.thread_name.clone()) + .spawn(move || { + let (command_tx, command_rx) = flume::bounded(params.command_channel_size); + + let conn = match params.establish() { + Ok(conn) => conn, + Err(e) => { + establish_tx.send(Err(e)).ok(); + return; + } + }; + + let shared = Arc::new(WorkerSharedState { + cached_statements_size: AtomicUsize::new(0), + // note: must be fair because in `Command::UnlockDb` we unlock the mutex + // and then immediately try to relock it; an unfair mutex would immediately + // grant us the lock even if another task is waiting. + conn: Mutex::new(conn, true), + }); + let mut conn = shared.conn.try_lock().unwrap(); + + if establish_tx + .send(Ok(Self { + command_tx, + handle_raw: conn.handle.to_raw(), + shared: Arc::clone(&shared), + })) + .is_err() + { + return; + } + + for cmd in command_rx { + match cmd { + Command::Prepare { query, tx } => { + tx.send(prepare(&mut conn, &query).map(|prepared| { + update_cached_statements_size( + &conn, + &shared.cached_statements_size, + ); + prepared + })) + .ok(); + } + Command::Describe { query, tx } => { + tx.send(describe(&mut conn, &query)).ok(); + } + Command::Execute { + query, + arguments, + persistent, + tx, + } => { + let iter = match execute::iter(&mut conn, &query, arguments, persistent) + { + Ok(iter) => iter, + Err(e) => { + tx.send(Err(e)).ok(); + continue; + } + }; + + for res in iter { + if tx.send(res).is_err() { + break; + } + } + + update_cached_statements_size(&conn, &shared.cached_statements_size); + } + Command::Begin { tx } => { + let depth = conn.transaction_depth; + let res = + conn.handle + .exec(begin_ansi_transaction_sql(depth)) + .map(|_| { + conn.transaction_depth += 1; + }); + + tx.send(res).ok(); + } + Command::Commit { tx } => { + let depth = conn.transaction_depth; + + let res = if depth > 0 { + conn.handle + .exec(commit_ansi_transaction_sql(depth)) + .map(|_| { + conn.transaction_depth -= 1; + }) + } else { + Ok(()) + }; + + tx.send(res).ok(); + } + Command::Rollback { tx } => { + let depth = conn.transaction_depth; + + let res = if depth > 0 { + conn.handle + .exec(rollback_ansi_transaction_sql(depth)) + .map(|_| { + conn.transaction_depth -= 1; + }) + } else { + Ok(()) + }; + + if let Some(tx) = tx { + tx.send(res).ok(); + } + } + Command::CreateCollation { create_collation } => { + if let Err(e) = (create_collation)(&mut conn) { + log::warn!("error applying collation in background worker: {}", e); + } + } + Command::ClearCache { tx } => { + conn.statements.clear(); + update_cached_statements_size(&conn, &shared.cached_statements_size); + tx.send(()).ok(); + } + Command::UnlockDb => { + drop(conn); + conn = futures_executor::block_on(shared.conn.lock()); + } + Command::Ping { tx } => { + tx.send(()).ok(); + } + Command::Shutdown { tx } => { + // drop the connection reference before sending confirmation + // and ending the command loop + drop(conn); + let _ = tx.send(()); + return; + } + } + } + })?; + + establish_rx.await.map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn prepare(&mut self, query: &str) -> Result, Error> { + self.oneshot_cmd(|tx| Command::Prepare { + query: query.into(), + tx, + }) + .await? + } + + pub(crate) async fn describe(&mut self, query: &str) -> Result, Error> { + self.oneshot_cmd(|tx| Command::Describe { + query: query.into(), + tx, + }) + .await? + } + + pub(crate) async fn execute( + &mut self, + query: &str, + args: Option>, + chan_size: usize, + persistent: bool, + ) -> Result, Error>>, Error> { + let (tx, rx) = flume::bounded(chan_size); + + self.command_tx + .send_async(Command::Execute { + query: query.into(), + arguments: args.map(SqliteArguments::into_static), + persistent, + tx, + }) + .await + .map_err(|_| Error::WorkerCrashed)?; + + Ok(rx) + } + + pub(crate) async fn begin(&mut self) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Begin { tx }).await? + } + + pub(crate) async fn commit(&mut self) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Commit { tx }).await? + } + + pub(crate) async fn rollback(&mut self) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) }) + .await? + } + + pub(crate) fn start_rollback(&mut self) -> Result<(), Error> { + self.command_tx + .send(Command::Rollback { tx: None }) + .map_err(|_| Error::WorkerCrashed) + } + + pub(crate) async fn ping(&mut self) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Ping { tx }).await + } + + async fn oneshot_cmd(&mut self, command: F) -> Result + where + F: FnOnce(oneshot::Sender) -> Command, + { + let (tx, rx) = oneshot::channel(); + + self.command_tx + .send_async(command(tx)) + .await + .map_err(|_| Error::WorkerCrashed)?; + + rx.await.map_err(|_| Error::WorkerCrashed) + } + + pub fn create_collation( + &mut self, + name: &str, + compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static, + ) -> Result<(), Error> { + let name = name.to_string(); + + self.command_tx + .send(Command::CreateCollation { + create_collation: Box::new(move |conn| { + create_collation(&mut conn.handle, &name, compare) + }), + }) + .map_err(|_| Error::WorkerCrashed)?; + Ok(()) + } + + pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::ClearCache { tx }).await + } + + pub(crate) async fn unlock_db(&mut self) -> Result, Error> { + let (guard, res) = futures_util::future::join( + // we need to join the wait queue for the lock before we send the message + self.shared.conn.lock(), + self.command_tx.send_async(Command::UnlockDb), + ) + .await; + + res.map_err(|_| Error::WorkerCrashed)?; + + Ok(guard) + } + + /// Send a command to the worker to shut down the processing thread. + /// + /// A `WorkerCrashed` error may be returned if the thread has already stopped. + pub(crate) fn shutdown(&mut self) -> impl Future> { + let (tx, rx) = oneshot::channel(); + + let send_res = self + .command_tx + .send(Command::Shutdown { tx }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } +} + +fn prepare(conn: &mut ConnectionState, query: &str) -> Result, Error> { + // prepare statement object (or checkout from cache) + let statement = conn.statements.get(query, true)?; + + let mut parameters = 0; + let mut columns = None; + let mut column_names = None; + + while let Some(statement) = statement.prepare_next(&mut conn.handle)? { + parameters += statement.handle.bind_parameter_count(); + + // the first non-empty statement is chosen as the statement we pull columns from + if !statement.columns.is_empty() && columns.is_none() { + columns = Some(Arc::clone(statement.columns)); + column_names = Some(Arc::clone(statement.column_names)); + } + } + + Ok(SqliteStatement { + sql: Cow::Owned(query.to_string()), + columns: columns.unwrap_or_default(), + column_names: column_names.unwrap_or_default(), + parameters, + }) +} + +fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) { + size.store(conn.statements.len(), Ordering::Release); +} diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index 5be8cbfd92..7810641716 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -5,6 +5,21 @@ // invariants. #![allow(unsafe_code)] +pub use arguments::{SqliteArgumentValue, SqliteArguments}; +pub use column::SqliteColumn; +pub use connection::{LockedSqliteHandle, SqliteConnection}; +pub use database::Sqlite; +pub use error::SqliteError; +pub use options::{ + SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqliteLockingMode, SqliteSynchronous, +}; +pub use query_result::SqliteQueryResult; +pub use row::SqliteRow; +pub use statement::SqliteStatement; +pub use transaction::SqliteTransactionManager; +pub use type_info::SqliteTypeInfo; +pub use value::{SqliteValue, SqliteValueRef}; + use crate::executor::Executor; mod arguments; @@ -24,21 +39,6 @@ mod value; #[cfg(feature = "migrate")] mod migrate; -pub use arguments::{SqliteArgumentValue, SqliteArguments}; -pub use column::SqliteColumn; -pub use connection::SqliteConnection; -pub use database::Sqlite; -pub use error::SqliteError; -pub use options::{ - SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqliteLockingMode, SqliteSynchronous, -}; -pub use query_result::SqliteQueryResult; -pub use row::SqliteRow; -pub use statement::SqliteStatement; -pub use transaction::SqliteTransactionManager; -pub use type_info::SqliteTypeInfo; -pub use value::{SqliteValue, SqliteValueRef}; - /// An alias for [`Pool`][crate::pool::Pool], specialized for SQLite. pub type SqlitePool = crate::pool::Pool; diff --git a/sqlx-core/src/sqlite/options/connect.rs b/sqlx-core/src/sqlite/options/connect.rs index 60eb3bd275..a087c43c3f 100644 --- a/sqlx-core/src/sqlite/options/connect.rs +++ b/sqlx-core/src/sqlite/options/connect.rs @@ -1,7 +1,6 @@ use crate::connection::ConnectOptions; use crate::error::Error; use crate::executor::Executor; -use crate::sqlite::connection::establish::establish; use crate::sqlite::{SqliteConnectOptions, SqliteConnection}; use futures_core::future::BoxFuture; use log::LevelFilter; @@ -16,7 +15,7 @@ impl ConnectOptions for SqliteConnectOptions { Self::Connection: Sized, { Box::pin(async move { - let mut conn = establish(self).await?; + let mut conn = SqliteConnection::establish(self).await?; // send an initial sql statement comprised of options let mut init = String::new(); @@ -27,7 +26,7 @@ impl ConnectOptions for SqliteConnectOptions { write!(init, "PRAGMA key = {}; ", pragma_key_password).ok(); } - for (key, value) in self.pragmas.iter() { + for (key, value) in &self.pragmas { // Since we've already written the possible `key` pragma // above, we shall skip it now. if key == "key" { @@ -38,6 +37,14 @@ impl ConnectOptions for SqliteConnectOptions { conn.execute(&*init).await?; + if !self.collations.is_empty() { + let mut locked = conn.lock_handle().await?; + + for collation in &self.collations { + collation.create(&mut locked.guard.handle)?; + } + } + Ok(conn) }) } diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index 9db122f355..3420273556 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -11,9 +11,13 @@ use crate::connection::LogSettings; pub use auto_vacuum::SqliteAutoVacuum; pub use journal_mode::SqliteJournalMode; pub use locking_mode::SqliteLockingMode; +use std::cmp::Ordering; +use std::sync::Arc; use std::{borrow::Cow, time::Duration}; pub use synchronous::SqliteSynchronous; +use crate::common::DebugFn; +use crate::sqlite::connection::collation::Collation; use indexmap::IndexMap; /// Options and flags which can be used to configure a SQLite connection. @@ -61,7 +65,14 @@ pub struct SqliteConnectOptions { pub(crate) log_settings: LogSettings, pub(crate) immutable: bool, pub(crate) pragmas: IndexMap, Cow<'static, str>>, + + pub(crate) command_channel_size: usize, + pub(crate) row_channel_size: usize, + + pub(crate) collations: Vec, + pub(crate) serialized: bool, + pub(crate) thread_name: Arc String + Send + Sync + 'static>>, } impl Default for SqliteConnectOptions { @@ -71,6 +82,9 @@ impl Default for SqliteConnectOptions { } impl SqliteConnectOptions { + /// Construct `Self` with default options. + /// + /// See the source of this method for the current defaults. pub fn new() -> Self { // set default pragmas let mut pragmas: IndexMap, Cow<'static, str>> = IndexMap::new(); @@ -110,7 +124,11 @@ impl SqliteConnectOptions { log_settings: Default::default(), immutable: false, pragmas, + collations: Default::default(), serialized: false, + thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))), + command_channel_size: 50, + row_channel_size: 50, } } @@ -232,6 +250,44 @@ impl SqliteConnectOptions { self } + /// Add a custom collation for comparing strings in SQL. + /// + /// If a collation with the same name already exists, it will be replaced. + /// + /// See [`sqlite3_create_collation()`](https://www.sqlite.org/c3ref/create_collation.html) for details. + /// + /// Note this excerpt: + /// > The collating function must obey the following properties for all strings A, B, and C: + /// > + /// > If A==B then B==A. + /// > If A==B and B==C then A==C. + /// > If A\A. + /// > If A + /// > If a collating function fails any of the above constraints and that collating function is + /// > registered and used, then the behavior of SQLite is undefined. + pub fn collation(mut self, name: N, collate: F) -> Self + where + N: Into>, + F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, + { + self.collations.push(Collation::new(name, collate)); + self + } + + /// Set to `true` to signal to SQLite that the database file is on read-only media. + /// + /// If enabled, SQLite assumes the database file _cannot_ be modified, even by higher + /// privileged processes, and so disables locking and change detection. This is intended + /// to improve performance but can produce incorrect query results or errors if the file + /// _does_ change. + /// + /// Note that this is different from the `SQLITE_OPEN_READONLY` flag set by + /// [`.read_only()`][Self::read_only], though the documentation suggests that this + /// does _imply_ `SQLITE_OPEN_READONLY`. + /// + /// See [`sqlite3_open`](https://www.sqlite.org/capi3ref.html#sqlite3_open) (subheading + /// "URI Filenames") for details. pub fn immutable(mut self, immutable: bool) -> Self { self.immutable = immutable; self @@ -242,8 +298,49 @@ impl SqliteConnectOptions { /// The default setting is `false` corersponding to using `OPEN_NOMUTEX`, if `true` then `OPEN_FULLMUTEX`. /// /// See [open](https://www.sqlite.org/c3ref/open.html) for more details. + /// + /// ### Note + /// Setting this to `true` may help if you are getting access violation errors or segmentation + /// faults, but will also incur a significant performance penalty. You should leave this + /// set to `false` if at all possible. + /// + /// If you do end up needing to set this to `true` for some reason, please + /// [open an issue](https://github.com/launchbadge/sqlx/issues/new/choose) as this may indicate + /// a concurrency bug in SQLx. Please provide clear instructions for reproducing the issue, + /// including a sample database schema if applicable. pub fn serialized(mut self, serialized: bool) -> Self { self.serialized = serialized; self } + + /// Provide a callback to generate the name of the background worker thread. + /// + /// The value passed to the callback is an auto-incremented integer for use as the thread ID. + pub fn thread_name( + mut self, + generator: impl Fn(u64) -> String + Send + Sync + 'static, + ) -> Self { + self.thread_name = Arc::new(DebugFn(generator)); + self + } + + /// Set the maximum number of commands to buffer for the worker thread before backpressure is + /// applied. + /// + /// Given that most commands sent to the worker thread involve waiting for a result, + /// the command channel is unlikely to fill up unless a lot queries are executed in a short + /// period but cancelled before their full resultsets are returned. + pub fn command_buffer_size(mut self, size: usize) -> Self { + self.command_channel_size = size; + self + } + + /// Set the maximum number of rows to buffer back to the calling task when a query is executed. + /// + /// If the calling task cannot keep up, backpressure will be applied to the worker thread + /// in order to limit CPU and memory usage. + pub fn row_buffer_size(mut self, size: usize) -> Self { + self.row_channel_size = size; + self + } } diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 4199915fe1..6caefd52b2 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -1,9 +1,6 @@ #![allow(clippy::rc_buffer)] -use std::ptr::null_mut; -use std::slice; -use std::sync::atomic::{AtomicPtr, Ordering}; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use crate::HashMap; @@ -11,23 +8,12 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::row::Row; -use crate::sqlite::statement::{StatementHandle, StatementHandleRef}; +use crate::sqlite::statement::StatementHandle; use crate::sqlite::{Sqlite, SqliteColumn, SqliteValue, SqliteValueRef}; /// Implementation of [`Row`] for SQLite. pub struct SqliteRow { - // Raw handle of the SQLite statement - // This is valid to access IFF the atomic [values] is null - // The way this works is that the executor retains a weak reference to - // [values] after the Row is created and yielded downstream. - // IF the user drops the Row before iterating the stream (so - // nearly all of our internal stream iterators), the executor moves on; otherwise, - // it actually inflates this row with a list of owned sqlite3 values. - pub(crate) statement: StatementHandleRef, - - pub(crate) values: Arc>, - pub(crate) num_values: usize, - + pub(crate) values: Box<[SqliteValue]>, pub(crate) columns: Arc>, pub(crate) column_names: Arc>, } @@ -44,37 +30,11 @@ unsafe impl Send for SqliteRow {} unsafe impl Sync for SqliteRow {} impl SqliteRow { - // creates a new row that is internally referencing the **current** state of the statement - // returns a weak reference to an atomic list where the executor should inflate if its going - // to increment the statement with [step] pub(crate) fn current( - statement: StatementHandleRef, + statement: &StatementHandle, columns: &Arc>, column_names: &Arc>, - ) -> (Self, Weak>) { - let values = Arc::new(AtomicPtr::new(null_mut())); - let weak_values = Arc::downgrade(&values); - let size = statement.column_count(); - - let row = Self { - statement, - values, - num_values: size, - columns: Arc::clone(columns), - column_names: Arc::clone(column_names), - }; - - (row, weak_values) - } - - // inflates this Row into memory as a list of owned, protected SQLite value objects - // this is called by the - #[allow(clippy::needless_range_loop)] - pub(crate) fn inflate( - statement: &StatementHandle, - columns: &[SqliteColumn], - values_ref: &AtomicPtr, - ) { + ) -> Self { let size = statement.column_count(); let mut values = Vec::with_capacity(size); @@ -86,20 +46,10 @@ impl SqliteRow { }); } - // decay the array signifier and become just a normal, leaked array - let values_ptr = Box::into_raw(values.into_boxed_slice()) as *mut SqliteValue; - - // store in the atomic ptr storage - values_ref.store(values_ptr, Ordering::Release); - } - - pub(crate) fn inflate_if_needed( - statement: &StatementHandle, - columns: &[SqliteColumn], - weak_values_ref: Option>>, - ) { - if let Some(v) = weak_values_ref.and_then(|v| v.upgrade()) { - SqliteRow::inflate(statement, &columns, &v); + Self { + values: values.into_boxed_slice(), + columns: Arc::clone(columns), + column_names: Arc::clone(column_names), } } } @@ -116,34 +66,7 @@ impl Row for SqliteRow { I: ColumnIndex, { let index = index.index(self)?; - - let values_ptr = self.values.load(Ordering::Acquire); - if !values_ptr.is_null() { - // we have raw value data, we should use that - let values: &[SqliteValue] = - unsafe { slice::from_raw_parts(values_ptr, self.num_values) }; - - Ok(SqliteValueRef::value(&values[index])) - } else { - Ok(SqliteValueRef::statement( - &self.statement, - self.columns[index].type_info.clone(), - index, - )) - } - } -} - -impl Drop for SqliteRow { - fn drop(&mut self) { - // if there is a non-null pointer stored here, we need to re-load and drop it - let values_ptr = self.values.load(Ordering::Acquire); - if !values_ptr.is_null() { - let values: &mut [SqliteValue] = - unsafe { slice::from_raw_parts_mut(values_ptr, self.num_values) }; - - let _ = unsafe { Box::from_raw(values) }; - } + Ok(SqliteValueRef::value(&self.values[index])) } } diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index 27e7b59020..4f196e55ab 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -14,46 +14,31 @@ use libsqlite3_sys::{ sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, - sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_sql, - sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, - SQLITE_MISUSE, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset, + sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, + sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT, + SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; -use crate::sqlite::connection::ConnectionHandleRef; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; -use std::ops::Deref; -use std::sync::Arc; #[derive(Debug)] pub(crate) struct StatementHandle(NonNull); -// wrapper for `Arc` which also holds a reference to the `ConnectionHandle` -#[derive(Clone, Debug)] -pub(crate) struct StatementHandleRef { - // NOTE: the ordering of fields here determines the drop order: - // https://doc.rust-lang.org/reference/destructors.html#destructors - // the statement *must* be dropped before the connection - statement: Arc, - connection: ConnectionHandleRef, -} - // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. unsafe impl Send for StatementHandle {} -unsafe impl Sync for StatementHandle {} +// might use some of this later +#[allow(dead_code)] impl StatementHandle { pub(super) fn new(ptr: NonNull) -> Self { Self(ptr) } - pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt { - self.0.as_ptr() - } - #[inline] pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 { // O(c) access to the connection handle for this statement handle @@ -306,13 +291,26 @@ impl StatementHandle { unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; } - pub(crate) fn to_ref( - self: &Arc, - conn: ConnectionHandleRef, - ) -> StatementHandleRef { - StatementHandleRef { - statement: Arc::clone(self), - connection: conn, + pub(crate) fn reset(&mut self) -> Result<(), SqliteError> { + // SAFETY: we have exclusive access to the handle + unsafe { + if sqlite3_reset(self.0.as_ptr()) != SQLITE_OK { + return Err(SqliteError::new(self.db_handle())); + } + } + + Ok(()) + } + + pub(crate) fn step(&mut self) -> Result { + // SAFETY: we have exclusive access to the handle + unsafe { + match sqlite3_step(self.0.as_ptr()) { + SQLITE_ROW => Ok(true), + SQLITE_DONE => Ok(false), + SQLITE_MISUSE => panic!("misuse!"), + _ => Err(SqliteError::new(self.db_handle())), + } } } } @@ -335,11 +333,3 @@ impl Drop for StatementHandle { } } } - -impl Deref for StatementHandleRef { - type Target = StatementHandle; - - fn deref(&self) -> &Self::Target { - &self.statement - } -} diff --git a/sqlx-core/src/sqlite/statement/mod.rs b/sqlx-core/src/sqlite/statement/mod.rs index 97ca9f8685..759aca5539 100644 --- a/sqlx-core/src/sqlite/statement/mod.rs +++ b/sqlx-core/src/sqlite/statement/mod.rs @@ -10,11 +10,9 @@ use std::sync::Arc; mod handle; mod r#virtual; -mod worker; -pub(crate) use handle::{StatementHandle, StatementHandleRef}; +pub(crate) use handle::StatementHandle; pub(crate) use r#virtual::VirtualStatement; -pub(crate) use worker::StatementWorker; #[derive(Debug, Clone)] #[allow(clippy::rc_buffer)] diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 3da6d33d64..aa3f16a027 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,106 +3,58 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::statement::{StatementHandle, StatementWorker}; -use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; +use crate::sqlite::statement::StatementHandle; +use crate::sqlite::{SqliteColumn, SqliteError}; use crate::HashMap; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, }; use smallvec::SmallVec; -use std::i32; use std::os::raw::c_char; use std::ptr::{null, null_mut, NonNull}; -use std::sync::{atomic::AtomicPtr, Arc, Weak}; +use std::sync::Arc; +use std::{cmp, i32}; // A virtual statement consists of *zero* or more raw SQLite3 statements. We chop up a SQL statement // on `;` to support multiple statements in one query. #[derive(Debug)] -pub(crate) struct VirtualStatement { +pub struct VirtualStatement { persistent: bool, - index: usize, - // tail of the most recently prepared SQL statement within this container + /// the current index of the actual statement that is executing + /// if `None`, no statement is executing and `prepare()` must be called; + /// if `Some(self.handles.len())` and `self.tail.is_empty()`, + /// there are no more statements to execute and `reset()` must be called + index: Option, + + /// tail of the most recently prepared SQL statement within this container tail: Bytes, - // underlying sqlite handles for each inner statement - // a SQL query string in SQLite is broken up into N statements - // we use a [`SmallVec`] to optimize for the most likely case of a single statement - pub(crate) handles: SmallVec<[Arc; 1]>, + /// underlying sqlite handles for each inner statement + /// a SQL query string in SQLite is broken up into N statements + /// we use a [`SmallVec`] to optimize for the most likely case of a single statement + pub(crate) handles: SmallVec<[StatementHandle; 1]>, // each set of columns pub(crate) columns: SmallVec<[Arc>; 1]>, // each set of column names pub(crate) column_names: SmallVec<[Arc>; 1]>, - - // weak reference to the previous row from this connection - // we use the notice of a successful upgrade of this reference as an indicator that the - // row is still around, in which we then inflate the row such that we can let SQLite - // clobber the memory allocation for the row - pub(crate) last_row_values: SmallVec<[Option>>; 1]>, } -fn prepare( - conn: *mut sqlite3, - query: &mut Bytes, - persistent: bool, -) -> Result, Error> { - let mut flags = 0; - - if persistent { - // SQLITE_PREPARE_PERSISTENT - // The SQLITE_PREPARE_PERSISTENT flag is a hint to the query - // planner that the prepared statement will be retained for a long time - // and probably reused many times. - flags |= SQLITE_PREPARE_PERSISTENT; - } - - while !query.is_empty() { - let mut statement_handle: *mut sqlite3_stmt = null_mut(); - let mut tail: *const c_char = null(); - - let query_ptr = query.as_ptr() as *const c_char; - let query_len = query.len() as i32; - - // - let status = unsafe { - sqlite3_prepare_v3( - conn, - query_ptr, - query_len, - flags as u32, - &mut statement_handle, - &mut tail, - ) - }; - - if status != SQLITE_OK { - return Err(SqliteError::new(conn).into()); - } - - // tail should point to the first byte past the end of the first SQL - // statement in zSql. these routines only compile the first statement, - // so tail is left pointing to what remains un-compiled. - - let n = (tail as usize) - (query_ptr as usize); - query.advance(n); - - if let Some(handle) = NonNull::new(statement_handle) { - return Ok(Some(StatementHandle::new(handle))); - } - } - - Ok(None) +pub struct PreparedStatement<'a> { + pub(crate) handle: &'a mut StatementHandle, + pub(crate) columns: &'a Arc>, + pub(crate) column_names: &'a Arc>, } impl VirtualStatement { pub(crate) fn new(mut query: &str, persistent: bool) -> Result { query = query.trim(); - if query.len() > i32::MAX as usize { + if query.len() > i32::max_value() as usize { return Err(err_protocol!( "query string must be smaller than {} bytes", i32::MAX @@ -113,26 +65,23 @@ impl VirtualStatement { persistent, tail: Bytes::from(String::from(query)), handles: SmallVec::with_capacity(1), - index: 0, + index: None, columns: SmallVec::with_capacity(1), column_names: SmallVec::with_capacity(1), - last_row_values: SmallVec::with_capacity(1), }) } - pub(crate) fn prepare( + pub(crate) fn prepare_next( &mut self, conn: &mut ConnectionHandle, - ) -> Result< - Option<( - &Arc, - &mut Arc>, - &Arc>, - &mut Option>>, - )>, - Error, - > { - while self.handles.len() == self.index { + ) -> Result>, Error> { + // increment `self.index` up to `self.handles.len()` + self.index = self + .index + .map(|idx| cmp::min(idx + 1, self.handles.len())) + .or(Some(0)); + + while self.handles.len() <= self.index.unwrap_or(0) { if self.tail.is_empty() { return Ok(None); } @@ -158,34 +107,30 @@ impl VirtualStatement { column_names.insert(name, i); } - self.handles.push(Arc::new(statement)); + self.handles.push(statement); self.columns.push(Arc::new(columns)); self.column_names.push(Arc::new(column_names)); - self.last_row_values.push(None); } } - let index = self.index; - self.index += 1; - - Ok(Some(( - &self.handles[index], - &mut self.columns[index], - &self.column_names[index], - &mut self.last_row_values[index], - ))) + Ok(self.current()) } - pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> { - self.index = 0; + pub fn current(&mut self) -> Option> { + self.index + .filter(|&idx| idx < self.handles.len()) + .map(move |idx| PreparedStatement { + handle: &mut self.handles[idx], + columns: &self.columns[idx], + column_names: &self.column_names[idx], + }) + } - for (i, handle) in self.handles.iter().enumerate() { - SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); + pub fn reset(&mut self) -> Result<(), Error> { + self.index = None; - // Reset A Prepared Statement Object - // https://www.sqlite.org/c3ref/reset.html - // https://www.sqlite.org/c3ref/clear_bindings.html - worker.reset(handle).await?; + for handle in self.handles.iter_mut() { + handle.reset()?; handle.clear_bindings(); } @@ -193,10 +138,55 @@ impl VirtualStatement { } } -impl Drop for VirtualStatement { - fn drop(&mut self) { - for (i, handle) in self.handles.drain(..).enumerate() { - SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); +fn prepare( + conn: *mut sqlite3, + query: &mut Bytes, + persistent: bool, +) -> Result, Error> { + let mut flags = 0; + + if persistent { + // SQLITE_PREPARE_PERSISTENT + // The SQLITE_PREPARE_PERSISTENT flag is a hint to the query + // planner that the prepared statement will be retained for a long time + // and probably reused many times. + flags |= SQLITE_PREPARE_PERSISTENT; + } + + while !query.is_empty() { + let mut statement_handle: *mut sqlite3_stmt = null_mut(); + let mut tail: *const c_char = null(); + + let query_ptr = query.as_ptr() as *const c_char; + let query_len = query.len() as i32; + + // + let status = unsafe { + sqlite3_prepare_v3( + conn, + query_ptr, + query_len, + flags as u32, + &mut statement_handle, + &mut tail, + ) + }; + + if status != SQLITE_OK { + return Err(SqliteError::new(conn).into()); + } + + // tail should point to the first byte past the end of the first SQL + // statement in zSql. these routines only compile the first statement, + // so tail is left pointing to what remains un-compiled. + + let n = (tail as usize) - (query_ptr as usize); + query.advance(n); + + if let Some(handle) = NonNull::new(statement_handle) { + return Ok(Some(StatementHandle::new(handle))); } } + + Ok(None) } diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs deleted file mode 100644 index 5a06f637b0..0000000000 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ /dev/null @@ -1,161 +0,0 @@ -use crate::error::Error; -use crate::sqlite::statement::StatementHandle; -use crossbeam_channel::{unbounded, Sender}; -use either::Either; -use futures_channel::oneshot; -use std::sync::{Arc, Weak}; -use std::thread; - -use crate::sqlite::connection::ConnectionHandleRef; - -use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW}; -use std::future::Future; - -// Each SQLite connection has a dedicated thread. - -// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce -// OS resource usage. Low priority because a high concurrent load for SQLite3 is very -// unlikely. - -pub(crate) struct StatementWorker { - tx: Sender, -} - -enum StatementWorkerCommand { - Step { - statement: Weak, - tx: oneshot::Sender, Error>>, - }, - Reset { - statement: Weak, - tx: oneshot::Sender<()>, - }, - Shutdown { - tx: oneshot::Sender<()>, - }, -} - -impl StatementWorker { - pub(crate) fn new(conn: ConnectionHandleRef) -> Self { - let (tx, rx) = unbounded(); - - thread::spawn(move || { - for cmd in rx { - match cmd { - StatementWorkerCommand::Step { statement, tx } => { - let statement = if let Some(statement) = statement.upgrade() { - statement - } else { - // statement is already finalized, the sender shouldn't be expecting a response - continue; - }; - - // SAFETY: only the `StatementWorker` calls this function - let status = unsafe { sqlite3_step(statement.as_ptr()) }; - let result = match status { - SQLITE_ROW => Ok(Either::Right(())), - SQLITE_DONE => Ok(Either::Left(statement.changes())), - _ => Err(statement.last_error().into()), - }; - - let _ = tx.send(result); - } - StatementWorkerCommand::Reset { statement, tx } => { - if let Some(statement) = statement.upgrade() { - // SAFETY: this must be the only place we call `sqlite3_reset` - unsafe { sqlite3_reset(statement.as_ptr()) }; - - // `sqlite3_reset()` always returns either `SQLITE_OK` - // or the last error code for the statement, - // which should have already been handled; - // so it's assumed the return value is safe to ignore. - // - // https://www.sqlite.org/c3ref/reset.html - - let _ = tx.send(()); - } - } - StatementWorkerCommand::Shutdown { tx } => { - // drop the connection reference before sending confirmation - // and ending the command loop - drop(conn); - let _ = tx.send(()); - return; - } - } - } - - // SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx` - drop(conn); - }); - - Self { tx } - } - - pub(crate) async fn step( - &mut self, - statement: &Arc, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - - self.tx - .send(StatementWorkerCommand::Step { - statement: Arc::downgrade(statement), - tx, - }) - .map_err(|_| Error::WorkerCrashed)?; - - rx.await.map_err(|_| Error::WorkerCrashed)? - } - - /// Send a command to the worker to execute `sqlite3_reset()` next. - /// - /// This method is written to execute the sending of the command eagerly so - /// you do not need to await the returned future unless you want to. - /// - /// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error - /// in the statement execution which should have already been handled from `step()`. - pub(crate) fn reset( - &mut self, - statement: &Arc, - ) -> impl Future> { - // execute the sending eagerly so we don't need to spawn the future - let (tx, rx) = oneshot::channel(); - - let send_res = self - .tx - .send(StatementWorkerCommand::Reset { - statement: Arc::downgrade(statement), - tx, - }) - .map_err(|_| Error::WorkerCrashed); - - async move { - send_res?; - - // wait for the response - rx.await.map_err(|_| Error::WorkerCrashed) - } - } - - /// Send a command to the worker to shut down the processing thread. - /// - /// A `WorkerCrashed` error may be returned if the thread has already stopped. - /// Subsequent calls to `step()`, `reset()`, or this method will fail with - /// `WorkerCrashed`. Ensure that any associated statements are dropped first. - pub(crate) fn shutdown(&mut self) -> impl Future> { - let (tx, rx) = oneshot::channel(); - - let send_res = self - .tx - .send(StatementWorkerCommand::Shutdown { tx }) - .map_err(|_| Error::WorkerCrashed); - - async move { - send_res?; - - // wait for the response - rx.await.map_err(|_| Error::WorkerCrashed) - } - } -} diff --git a/sqlx-core/src/sqlite/transaction.rs b/sqlx-core/src/sqlite/transaction.rs index aad9c30cb5..cfd39ba4fa 100644 --- a/sqlx-core/src/sqlite/transaction.rs +++ b/sqlx-core/src/sqlite/transaction.rs @@ -1,15 +1,8 @@ -use std::ptr; - use futures_core::future::BoxFuture; -use libsqlite3_sys::{sqlite3_exec, SQLITE_OK}; use crate::error::Error; -use crate::executor::Executor; -use crate::sqlite::{Sqlite, SqliteConnection, SqliteError}; -use crate::transaction::{ - begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql, - TransactionManager, -}; +use crate::sqlite::{Sqlite, SqliteConnection}; +use crate::transaction::TransactionManager; /// Implementation of [`TransactionManager`] for SQLite. pub struct SqliteTransactionManager; @@ -18,71 +11,18 @@ impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { - let depth = conn.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.transaction_depth = depth + 1; - - Ok(()) - }) + Box::pin(conn.worker.begin()) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { - let depth = conn.transaction_depth; - - if depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(depth)).await?; - conn.transaction_depth = depth - 1; - } - - Ok(()) - }) + Box::pin(conn.worker.commit()) } fn rollback(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { - let depth = conn.transaction_depth; - - if depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql(depth)).await?; - conn.transaction_depth = depth - 1; - } - - Ok(()) - }) + Box::pin(conn.worker.rollback()) } fn start_rollback(conn: &mut SqliteConnection) { - let depth = conn.transaction_depth; - - if depth > 0 { - let query = rollback_ansi_transaction_sql(depth); - let mut z_query = String::with_capacity(query.len() + 1); - z_query.push_str(&query); - z_query.push('\0'); - - unsafe { - // NOTE: this is a direct execution as a ROLLBACK is unlikely to block - // for any amount of time - let status = sqlite3_exec( - conn.handle.as_ptr(), - z_query.as_ptr() as _, - None, - ptr::null_mut(), - ptr::null_mut(), - ); - - if status != SQLITE_OK { - panic!( - "error occurred while dropping a transaction: {}", - SqliteError::new(conn.handle.as_ptr()) - ); - } - } - - conn.transaction_depth = depth - 1; - } + conn.worker.start_rollback().ok(); } } diff --git a/sqlx-core/src/sqlite/value.rs b/sqlx-core/src/sqlite/value.rs index 5a9dd87c42..f46041df11 100644 --- a/sqlx-core/src/sqlite/value.rs +++ b/sqlx-core/src/sqlite/value.rs @@ -10,19 +10,12 @@ use libsqlite3_sys::{ }; use crate::error::BoxDynError; -use crate::sqlite::statement::StatementHandle; use crate::sqlite::type_info::DataType; use crate::sqlite::{Sqlite, SqliteTypeInfo}; use crate::value::{Value, ValueRef}; use std::borrow::Cow; enum SqliteValueData<'r> { - Statement { - statement: &'r StatementHandle, - type_info: SqliteTypeInfo, - index: usize, - }, - Value(&'r SqliteValue), } @@ -33,64 +26,32 @@ impl<'r> SqliteValueRef<'r> { Self(SqliteValueData::Value(value)) } - pub(crate) fn statement( - statement: &'r StatementHandle, - type_info: SqliteTypeInfo, - index: usize, - ) -> Self { - Self(SqliteValueData::Statement { - statement, - type_info, - index, - }) - } - pub(super) fn int(&self) -> i32 { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_int(index), - SqliteValueData::Value(v) => v.int(), } } pub(super) fn int64(&self) -> i64 { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_int64(index), - SqliteValueData::Value(v) => v.int64(), } } pub(super) fn double(&self) -> f64 { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_double(index), - SqliteValueData::Value(v) => v.double(), } } pub(super) fn blob(&self) -> &'r [u8] { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_blob(index), - SqliteValueData::Value(v) => v.blob(), } } pub(super) fn text(&self) -> Result<&'r str, BoxDynError> { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_text(index), - SqliteValueData::Value(v) => v.text(), } } @@ -101,12 +62,6 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { fn to_owned(&self) -> SqliteValue { match self.0 { - SqliteValueData::Statement { - statement, - index, - ref type_info, - } => unsafe { SqliteValue::new(statement.column_value(index), type_info.clone()) }, - SqliteValueData::Value(v) => v.clone(), } } @@ -114,24 +69,11 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { match self.0 { SqliteValueData::Value(v) => v.type_info(), - - SqliteValueData::Statement { - ref type_info, - statement, - index, - } => statement - .column_type_info_opt(index) - .map(Cow::Owned) - .unwrap_or(Cow::Borrowed(type_info)), } } fn is_null(&self) -> bool { match self.0 { - SqliteValueData::Statement { - statement, index, .. - } => statement.column_type(index) == SQLITE_NULL, - SqliteValueData::Value(v) => v.is_null(), } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 12f1834e8c..dfa79c7bb8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,8 +1,10 @@ use futures::TryStreamExt; -use sqlx::sqlite::SqlitePoolOptions; +use rand::{Rng, SeedableRng}; +use rand_xoshiro::Xoshiro256PlusPlus; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::{ - query, sqlite::Sqlite, sqlite::SqliteRow, Column, Connection, Executor, Row, SqliteConnection, - SqlitePool, Statement, TypeInfo, + query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, + SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx_test::new; @@ -389,7 +391,10 @@ SELECT id, text FROM _sqlx_test; async fn it_supports_collations() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.create_collation("test_collation", |l, r| l.cmp(r).reverse())?; + // also tests `.lock_handle()` + conn.lock_handle() + .await? + .create_collation("test_collation", |l, r| l.cmp(r).reverse())?; let _ = conn .execute( @@ -592,3 +597,64 @@ async fn row_dropped_after_connection_doesnt_panic() { sqlx_rt::sleep(std::time::Duration::from_secs(1)).await; drop(books); } + +// note: to repro issue #1467 this should be run in release mode +#[sqlx_macros::test] +async fn issue_1467() -> anyhow::Result<()> { + let mut conn = SqliteConnectOptions::new() + .filename(":memory:") + .connect() + .await?; + + sqlx::query( + r#" + CREATE TABLE kv (k PRIMARY KEY, v); + CREATE INDEX idx_kv ON kv (v); + "#, + ) + .execute(&mut conn) + .await?; + + // Random seed: + let seed: [u8; 32] = rand::random(); + println!("RNG seed: {}", hex::encode(&seed)); + + // Pre-determined seed: + // let mut seed: [u8; 32] = [0u8; 32]; + // hex::decode_to_slice( + // "135234871d03fc0479e22f2f06395b6074761bac5fe7dcf205dbe01eef9f7794", + // &mut seed, + // )?; + + // reproducible RNG for testing + let mut rng = Xoshiro256PlusPlus::from_seed(seed); + + for i in 0..1_000_000 { + if i % 1_000 == 0 { + println!("{}", i); + } + let key = rng.gen_range(0..1_000); + let value = rng.gen_range(0..1_000); + let mut tx = conn.begin().await?; + + let exists = sqlx::query("SELECT 1 FROM kv WHERE k = ?") + .bind(key) + .fetch_optional(&mut tx) + .await?; + if exists.is_some() { + sqlx::query("UPDATE kv SET v = ? WHERE k = ?") + .bind(value) + .bind(key) + .execute(&mut tx) + .await?; + } else { + sqlx::query("INSERT INTO kv(k, v) VALUES (?, ?)") + .bind(key) + .bind(value) + .execute(&mut tx) + .await?; + } + tx.commit().await?; + } + Ok(()) +}