diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index dd6e0f2b..77628659 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -6,19 +6,26 @@ use crate::value::CassCqlValue; use crate::{argconv::*, value}; use std::convert::TryFrom; use std::sync::Arc; +use std::sync::LazyLock; // These constants help us to save an allocation in case user calls `cass_collection_new` (untyped collection). -static UNTYPED_LIST_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::List { - typ: None, - frozen: false, +static UNTYPED_LIST_TYPE: LazyLock> = LazyLock::new(|| { + CassDataType::new_arced(CassDataTypeInner::List { + typ: None, + frozen: false, + }) }); -static UNTYPED_SET_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Set { - typ: None, - frozen: false, +static UNTYPED_SET_TYPE: LazyLock> = LazyLock::new(|| { + CassDataType::new_arced(CassDataTypeInner::Set { + typ: None, + frozen: false, + }) }); -static UNTYPED_MAP_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Map { - typ: MapDataType::Untyped, - frozen: false, +static UNTYPED_MAP_TYPE: LazyLock> = LazyLock::new(|| { + CassDataType::new_arced(CassDataTypeInner::Map { + typ: MapDataType::Untyped, + frozen: false, + }) }); #[derive(Clone)] @@ -183,9 +190,9 @@ unsafe extern "C" fn cass_collection_data_type( match &collection_ref.data_type { Some(dt) => ArcFFI::as_ptr(dt), None => match collection_ref.collection_type { - CassCollectionType::CASS_COLLECTION_TYPE_LIST => &UNTYPED_LIST_TYPE, - CassCollectionType::CASS_COLLECTION_TYPE_SET => &UNTYPED_SET_TYPE, - CassCollectionType::CASS_COLLECTION_TYPE_MAP => &UNTYPED_MAP_TYPE, + CassCollectionType::CASS_COLLECTION_TYPE_LIST => ArcFFI::as_ptr(&UNTYPED_LIST_TYPE), + CassCollectionType::CASS_COLLECTION_TYPE_SET => ArcFFI::as_ptr(&UNTYPED_SET_TYPE), + CassCollectionType::CASS_COLLECTION_TYPE_MAP => ArcFFI::as_ptr(&UNTYPED_MAP_TYPE), // CassCollectionType is a C enum. Panic, if it's out of range. _ => panic!( "CassCollectionType enum value out of range: {}", @@ -225,7 +232,10 @@ mod tests { use crate::{ argconv::ArcFFI, cass_error::CassError, - cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType}, + cass_types::{ + cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, CassDataType, + CassDataTypeInner, CassValueType, MapDataType, + }, collection::{ cass_collection_append_double, cass_collection_append_float, cass_collection_free, }, @@ -234,7 +244,8 @@ mod tests { use super::{ cass_bool_t, cass_collection_append_bool, cass_collection_append_int16, - cass_collection_new, cass_collection_new_from_data_type, CassCollectionType, + cass_collection_data_type, cass_collection_new, cass_collection_new_from_data_type, + CassCollectionType, }; #[test] @@ -498,4 +509,24 @@ mod tests { } } } + + #[test] + fn regression_empty_collection_data_type_test() { + // This is a regression test that checks whether collections return + // an Arc-based pointer for their type, even if they are empty. + // Previously, they would return the pointer to static data, but not Arc allocated. + unsafe { + let empty_list = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2); + + // This would previously return a non Arc-based pointer. + let empty_list_dt = cass_collection_data_type(empty_list); + + let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET); + // This will try to increment the reference count of `empty_list_dt`. + // Previously, this would fail, because `empty_list_dt` did not originate from an Arc allocation. + cass_data_type_add_sub_type(empty_set_dt, empty_list_dt); + + cass_data_type_free(empty_set_dt) + } + } } diff --git a/scylla-rust-wrapper/src/metadata.rs b/scylla-rust-wrapper/src/metadata.rs index 6ee0708e..d9a2af0f 100644 --- a/scylla-rust-wrapper/src/metadata.rs +++ b/scylla-rust-wrapper/src/metadata.rs @@ -51,7 +51,7 @@ impl RefFFI for CassMaterializedViewMeta {} pub struct CassColumnMeta { pub name: String, - pub column_type: CassDataType, + pub column_type: Arc, pub column_kind: CassColumnType, } @@ -66,7 +66,7 @@ pub fn create_table_metadata(table_name: &str, table_metadata: &Table) -> CassTa .for_each(|(column_name, column_metadata)| { let cass_column_meta = CassColumnMeta { name: column_name.clone(), - column_type: get_column_type(&column_metadata.typ), + column_type: Arc::new(get_column_type(&column_metadata.typ)), column_kind: match column_metadata.kind { ColumnKind::Regular => CassColumnType::CASS_COLUMN_TYPE_REGULAR, ColumnKind::Static => CassColumnType::CASS_COLUMN_TYPE_STATIC, @@ -299,7 +299,7 @@ pub unsafe extern "C" fn cass_column_meta_data_type( column_meta: *const CassColumnMeta, ) -> *const CassDataType { let column_meta = RefFFI::as_ref(column_meta); - &column_meta.column_type as *const CassDataType + ArcFFI::as_ptr(&column_meta.column_type) } #[no_mangle] @@ -427,7 +427,14 @@ pub unsafe extern "C" fn cass_materialized_view_meta_base_table( view_meta: *const CassMaterializedViewMeta, ) -> *const CassTableMeta { let view_meta = RefFFI::as_ref(view_meta); - view_meta.base_table.as_ptr() + + match view_meta.base_table.upgrade() { + Some(arc) => RefFFI::as_ptr(&arc), + None => { + tracing::error!("Failed to upgrade a weak reference to table metadata from materialized view metadata! This is a driver bug!"); + std::ptr::null() + } + } } #[no_mangle] diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index aade311e..a206ac6a 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -70,11 +70,7 @@ impl CassSessionInner { } fn connect( - // This reference is 'static because this is the only was of assuring the borrow checker - // that holding it in our returned future is sound. Ideally, we would prefer to have - // the returned future's lifetime constrained by real lifetime of the session's RwLock, - // but this is impossible to be guaranteed due to C/Rust cross-language barrier. - session_opt: &'static RwLock>, + session_opt: Arc>>, cluster: &CassCluster, keyspace: Option, ) -> *mut CassFuture { @@ -94,7 +90,7 @@ impl CassSessionInner { } async fn connect_fut( - session_opt: &RwLock>, + session_opt: Arc>>, session_builder_fut: impl Future, exec_profile_builder_map: HashMap, client_id: uuid::Uuid, @@ -154,7 +150,7 @@ pub unsafe extern "C" fn cass_session_connect( session_raw: *mut CassSession, cluster_raw: *const CassCluster, ) -> *mut CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); + let session_opt = ArcFFI::cloned_from_ptr(session_raw); let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); CassSessionInner::connect(session_opt, cluster, None) @@ -176,7 +172,7 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n( keyspace: *const c_char, keyspace_length: size_t, ) -> *mut CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); + let session_opt = ArcFFI::cloned_from_ptr(session_raw); let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); let keyspace = ptr_to_cstr_n(keyspace, keyspace_length).map(ToOwned::to_owned); @@ -188,7 +184,7 @@ pub unsafe extern "C" fn cass_session_execute_batch( session_raw: *mut CassSession, batch_raw: *const CassBatch, ) -> *mut CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); + let session_opt = ArcFFI::cloned_from_ptr(session_raw); let batch_from_raw = BoxFFI::as_ref(batch_raw); let mut state = batch_from_raw.state.clone(); let request_timeout_ms = batch_from_raw.batch_request_timeout_ms; @@ -254,7 +250,7 @@ pub unsafe extern "C" fn cass_session_execute( session_raw: *mut CassSession, statement_raw: *const CassStatement, ) -> *mut CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); + let session_opt = ArcFFI::cloned_from_ptr(session_raw); // DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault. let statement_opt = BoxFFI::as_ref(statement_raw); @@ -389,7 +385,7 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( cass_session: *mut CassSession, statement: *const CassStatement, ) -> *mut CassFuture { - let session = ArcFFI::as_ref(cass_session); + let session = ArcFFI::cloned_from_ptr(cass_session); let cass_statement = BoxFFI::as_ref(statement); let statement = cass_statement.statement.clone(); @@ -441,7 +437,7 @@ pub unsafe extern "C" fn cass_session_prepare_n( // There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`. .unwrap_or_default(); let query = Statement::new(query_str.to_string()); - let cass_session = ArcFFI::as_ref(cass_session_raw); + let cass_session = ArcFFI::cloned_from_ptr(cass_session_raw); CassFuture::make_raw(async move { let session_guard = cass_session.read().await; @@ -474,7 +470,7 @@ pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) { #[no_mangle] pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *mut CassFuture { - let session_opt = ArcFFI::as_ref(session); + let session_opt = ArcFFI::cloned_from_ptr(session); CassFuture::make_raw(async move { let mut session_guard = session_opt.write().await; diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 93602c63..cdf5cfa5 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -6,8 +6,10 @@ use crate::types::*; use crate::value; use crate::value::CassCqlValue; use std::sync::Arc; +use std::sync::LazyLock; -static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Tuple(Vec::new())); +static UNTYPED_TUPLE_TYPE: LazyLock> = + LazyLock::new(|| CassDataType::new_arced(CassDataTypeInner::Tuple(Vec::new()))); #[derive(Clone)] pub struct CassTuple { @@ -92,7 +94,7 @@ unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) { unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType { match &BoxFFI::as_ref(tuple).data_type { Some(t) => ArcFFI::as_ptr(t), - None => &UNTYPED_TUPLE_TYPE, + None => ArcFFI::as_ptr(&UNTYPED_TUPLE_TYPE), } } @@ -116,3 +118,32 @@ make_binders!(decimal, cass_tuple_set_decimal); make_binders!(collection, cass_tuple_set_collection); make_binders!(tuple, cass_tuple_set_tuple); make_binders!(user_type, cass_tuple_set_user_type); + +#[cfg(test)] +mod tests { + use crate::cass_types::{ + cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, CassValueType, + }; + + use super::{cass_tuple_data_type, cass_tuple_new}; + + #[test] + fn regression_empty_tuple_data_type_test() { + // This is a regression test that checks whether tuples return + // an Arc-based pointer for their type, even if they are empty. + // Previously, they would return the pointer to static data, but not Arc allocated. + unsafe { + let empty_tuple = cass_tuple_new(2); + + // This would previously return a non Arc-based pointer. + let empty_tuple_dt = cass_tuple_data_type(empty_tuple); + + let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET); + // This will try to increment the reference count of `empty_tuple_dt`. + // Previously, this would fail, because `empty_tuple_dt` did not originate from an Arc allocation. + cass_data_type_add_sub_type(empty_set_dt, empty_tuple_dt); + + cass_data_type_free(empty_set_dt) + } + } +}