diff --git a/sqlx-core/src/sqlite/connection/collation.rs b/sqlx-core/src/sqlite/connection/collation.rs new file mode 100644 index 0000000000..353fa7252c --- /dev/null +++ b/sqlx-core/src/sqlite/connection/collation.rs @@ -0,0 +1,74 @@ +use std::cmp::Ordering; +use std::ffi::CString; +use std::os::raw::{c_char, c_int, c_void}; +use std::slice; +use std::str::from_utf8_unchecked; + +use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; + +use crate::error::Error; +use crate::sqlite::connection::handle::ConnectionHandle; +use crate::sqlite::SqliteError; + +unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + drop(Box::from_raw(p as *mut T)); +} + +pub(crate) fn create_collation( + handle: &ConnectionHandle, + name: &str, + compare: F, +) -> Result<(), Error> +where + F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, +{ + unsafe extern "C" fn call_boxed_closure( + arg1: *mut c_void, + arg2: c_int, + arg3: *const c_void, + arg4: c_int, + arg5: *const c_void, + ) -> c_int + where + C: Fn(&str, &str) -> Ordering, + { + let boxed_f: *mut C = arg1 as *mut C; + debug_assert!(!boxed_f.is_null()); + let s1 = { + let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); + from_utf8_unchecked(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); + from_utf8_unchecked(c_slice) + }; + let t = (*boxed_f)(s1, s2); + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(compare)); + let c_name = + CString::new(name).map_err(|_| err_protocol!("invalid collation name: {}", name))?; + let flags = SQLITE_UTF8; + let r = unsafe { + sqlite3_create_collation_v2( + handle.as_ptr(), + c_name.as_ptr(), + flags, + boxed_f as *mut c_void, + Some(call_boxed_closure::), + Some(free_boxed_value::), + ) + }; + + if r == SQLITE_OK { + Ok(()) + } else { + Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + } +} diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 31ac8ac26e..12c7e713b1 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -15,6 +16,7 @@ use crate::sqlite::connection::establish::establish; use crate::sqlite::statement::{SqliteStatement, StatementWorker}; use crate::sqlite::{Sqlite, SqliteConnectOptions}; +mod collation; mod describe; mod establish; mod executor; @@ -43,6 +45,14 @@ impl SqliteConnection { pub fn as_raw_handle(&mut self) -> *mut sqlite3 { self.handle.as_ptr() } + + pub fn create_collation( + &mut self, + name: &str, + compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static, + ) -> Result<(), Error> { + collation::create_collation(&self.handle, name, compare) + } } impl Debug for SqliteConnection { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 3a2cf45360..daa79e3b1f 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,6 +1,8 @@ use futures::TryStreamExt; -use sqlx::sqlite::{Sqlite, SqliteConnection, SqlitePool, SqliteRow}; -use sqlx::{query, Connect, Connection, Executor, Row}; +use sqlx::{ + query, sqlite::Sqlite, sqlite::SqliteRow, Connect, Connection, Executor, Row, SqliteConnection, + SqlitePool, +}; use sqlx_test::new; #[sqlx_macros::test] @@ -303,6 +305,39 @@ SELECT id, text FROM _sqlx_test; Ok(()) } +#[sqlx_macros::test] +async fn it_supports_collations() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.create_collation("test_collation", |l, r| l.cmp(r).reverse())?; + + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL COLLATE test_collation) + "#, + ) + .await?; + + sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("a") + .execute(&mut conn) + .await?; + sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("b") + .execute(&mut conn) + .await?; + + let row: SqliteRow = conn + .fetch_one("SELECT name FROM users ORDER BY name ASC") + .await?; + let name: &str = row.try_get(0)?; + + assert_eq!(name, "b"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_caches_statements() -> anyhow::Result<()> { let mut conn = new::().await?;