Skip to content

safety: pointer related bugfixes #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions scylla-rust-wrapper/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,26 @@ use crate::value::CassCqlValue;
use crate::{argconv::*, value};
use std::convert::TryFrom;
use std::sync::Arc;
use std::sync::LazyLock;

// These constants help us to save an allocation in case user calls `cass_collection_new` (untyped collection).
static UNTYPED_LIST_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::List {
typ: None,
frozen: false,
static UNTYPED_LIST_TYPE: LazyLock<Arc<CassDataType>> = LazyLock::new(|| {
CassDataType::new_arced(CassDataTypeInner::List {
typ: None,
frozen: false,
})
});
static UNTYPED_SET_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Set {
typ: None,
frozen: false,
static UNTYPED_SET_TYPE: LazyLock<Arc<CassDataType>> = LazyLock::new(|| {
CassDataType::new_arced(CassDataTypeInner::Set {
typ: None,
frozen: false,
})
});
static UNTYPED_MAP_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Map {
typ: MapDataType::Untyped,
frozen: false,
static UNTYPED_MAP_TYPE: LazyLock<Arc<CassDataType>> = LazyLock::new(|| {
CassDataType::new_arced(CassDataTypeInner::Map {
typ: MapDataType::Untyped,
frozen: false,
})
});

#[derive(Clone)]
Expand Down Expand Up @@ -183,9 +190,9 @@ unsafe extern "C" fn cass_collection_data_type(
match &collection_ref.data_type {
Some(dt) => ArcFFI::as_ptr(dt),
None => match collection_ref.collection_type {
CassCollectionType::CASS_COLLECTION_TYPE_LIST => &UNTYPED_LIST_TYPE,
CassCollectionType::CASS_COLLECTION_TYPE_SET => &UNTYPED_SET_TYPE,
CassCollectionType::CASS_COLLECTION_TYPE_MAP => &UNTYPED_MAP_TYPE,
CassCollectionType::CASS_COLLECTION_TYPE_LIST => ArcFFI::as_ptr(&UNTYPED_LIST_TYPE),
CassCollectionType::CASS_COLLECTION_TYPE_SET => ArcFFI::as_ptr(&UNTYPED_SET_TYPE),
CassCollectionType::CASS_COLLECTION_TYPE_MAP => ArcFFI::as_ptr(&UNTYPED_MAP_TYPE),
// CassCollectionType is a C enum. Panic, if it's out of range.
_ => panic!(
"CassCollectionType enum value out of range: {}",
Expand Down Expand Up @@ -225,7 +232,10 @@ mod tests {
use crate::{
argconv::ArcFFI,
cass_error::CassError,
cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType},
cass_types::{
cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, CassDataType,
CassDataTypeInner, CassValueType, MapDataType,
},
collection::{
cass_collection_append_double, cass_collection_append_float, cass_collection_free,
},
Expand All @@ -234,7 +244,8 @@ mod tests {

use super::{
cass_bool_t, cass_collection_append_bool, cass_collection_append_int16,
cass_collection_new, cass_collection_new_from_data_type, CassCollectionType,
cass_collection_data_type, cass_collection_new, cass_collection_new_from_data_type,
CassCollectionType,
};

#[test]
Expand Down Expand Up @@ -498,4 +509,24 @@ mod tests {
}
}
}

#[test]
fn regression_empty_collection_data_type_test() {
// This is a regression test that checks whether collections return
// an Arc-based pointer for their type, even if they are empty.
// Previously, they would return the pointer to static data, but not Arc allocated.
unsafe {
let empty_list = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2);

// This would previously return a non Arc-based pointer.
let empty_list_dt = cass_collection_data_type(empty_list);

let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET);
// This will try to increment the reference count of `empty_list_dt`.
// Previously, this would fail, because `empty_list_dt` did not originate from an Arc allocation.
cass_data_type_add_sub_type(empty_set_dt, empty_list_dt);

cass_data_type_free(empty_set_dt)
}
}
}
15 changes: 11 additions & 4 deletions scylla-rust-wrapper/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl RefFFI for CassMaterializedViewMeta {}

pub struct CassColumnMeta {
pub name: String,
pub column_type: CassDataType,
pub column_type: Arc<CassDataType>,
pub column_kind: CassColumnType,
}

Expand All @@ -66,7 +66,7 @@ pub fn create_table_metadata(table_name: &str, table_metadata: &Table) -> CassTa
.for_each(|(column_name, column_metadata)| {
let cass_column_meta = CassColumnMeta {
name: column_name.clone(),
column_type: get_column_type(&column_metadata.typ),
column_type: Arc::new(get_column_type(&column_metadata.typ)),
column_kind: match column_metadata.kind {
ColumnKind::Regular => CassColumnType::CASS_COLUMN_TYPE_REGULAR,
ColumnKind::Static => CassColumnType::CASS_COLUMN_TYPE_STATIC,
Expand Down Expand Up @@ -299,7 +299,7 @@ pub unsafe extern "C" fn cass_column_meta_data_type(
column_meta: *const CassColumnMeta,
) -> *const CassDataType {
let column_meta = RefFFI::as_ref(column_meta);
&column_meta.column_type as *const CassDataType
ArcFFI::as_ptr(&column_meta.column_type)
}

#[no_mangle]
Expand Down Expand Up @@ -427,7 +427,14 @@ pub unsafe extern "C" fn cass_materialized_view_meta_base_table(
view_meta: *const CassMaterializedViewMeta,
) -> *const CassTableMeta {
let view_meta = RefFFI::as_ref(view_meta);
view_meta.base_table.as_ptr()

match view_meta.base_table.upgrade() {
Some(arc) => RefFFI::as_ptr(&arc),
None => {
tracing::error!("Failed to upgrade a weak reference to table metadata from materialized view metadata! This is a driver bug!");
std::ptr::null()
}
}
}

#[no_mangle]
Expand Down
22 changes: 9 additions & 13 deletions scylla-rust-wrapper/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,7 @@ impl CassSessionInner {
}

fn connect(
// This reference is 'static because this is the only was of assuring the borrow checker
// that holding it in our returned future is sound. Ideally, we would prefer to have
// the returned future's lifetime constrained by real lifetime of the session's RwLock,
// but this is impossible to be guaranteed due to C/Rust cross-language barrier.
session_opt: &'static RwLock<Option<CassSessionInner>>,
session_opt: Arc<RwLock<Option<CassSessionInner>>>,
cluster: &CassCluster,
keyspace: Option<String>,
) -> *mut CassFuture {
Expand All @@ -94,7 +90,7 @@ impl CassSessionInner {
}

async fn connect_fut(
session_opt: &RwLock<Option<CassSessionInner>>,
session_opt: Arc<RwLock<Option<CassSessionInner>>>,
session_builder_fut: impl Future<Output = SessionBuilder>,
exec_profile_builder_map: HashMap<ExecProfileName, CassExecProfile>,
client_id: uuid::Uuid,
Expand Down Expand Up @@ -154,7 +150,7 @@ pub unsafe extern "C" fn cass_session_connect(
session_raw: *mut CassSession,
cluster_raw: *const CassCluster,
) -> *mut CassFuture {
let session_opt = ArcFFI::as_ref(session_raw);
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw);

CassSessionInner::connect(session_opt, cluster, None)
Expand All @@ -176,7 +172,7 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n(
keyspace: *const c_char,
keyspace_length: size_t,
) -> *mut CassFuture {
let session_opt = ArcFFI::as_ref(session_raw);
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw);
let keyspace = ptr_to_cstr_n(keyspace, keyspace_length).map(ToOwned::to_owned);

Expand All @@ -188,7 +184,7 @@ pub unsafe extern "C" fn cass_session_execute_batch(
session_raw: *mut CassSession,
batch_raw: *const CassBatch,
) -> *mut CassFuture {
let session_opt = ArcFFI::as_ref(session_raw);
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
let batch_from_raw = BoxFFI::as_ref(batch_raw);
let mut state = batch_from_raw.state.clone();
let request_timeout_ms = batch_from_raw.batch_request_timeout_ms;
Expand Down Expand Up @@ -254,7 +250,7 @@ pub unsafe extern "C" fn cass_session_execute(
session_raw: *mut CassSession,
statement_raw: *const CassStatement,
) -> *mut CassFuture {
let session_opt = ArcFFI::as_ref(session_raw);
let session_opt = ArcFFI::cloned_from_ptr(session_raw);

// DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault.
let statement_opt = BoxFFI::as_ref(statement_raw);
Expand Down Expand Up @@ -389,7 +385,7 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
cass_session: *mut CassSession,
statement: *const CassStatement,
) -> *mut CassFuture {
let session = ArcFFI::as_ref(cass_session);
let session = ArcFFI::cloned_from_ptr(cass_session);
let cass_statement = BoxFFI::as_ref(statement);
let statement = cass_statement.statement.clone();

Expand Down Expand Up @@ -441,7 +437,7 @@ pub unsafe extern "C" fn cass_session_prepare_n(
// There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`.
.unwrap_or_default();
let query = Statement::new(query_str.to_string());
let cass_session = ArcFFI::as_ref(cass_session_raw);
let cass_session = ArcFFI::cloned_from_ptr(cass_session_raw);

CassFuture::make_raw(async move {
let session_guard = cass_session.read().await;
Expand Down Expand Up @@ -474,7 +470,7 @@ pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) {

#[no_mangle]
pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *mut CassFuture {
let session_opt = ArcFFI::as_ref(session);
let session_opt = ArcFFI::cloned_from_ptr(session);

CassFuture::make_raw(async move {
let mut session_guard = session_opt.write().await;
Expand Down
35 changes: 33 additions & 2 deletions scylla-rust-wrapper/src/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ use crate::types::*;
use crate::value;
use crate::value::CassCqlValue;
use std::sync::Arc;
use std::sync::LazyLock;

static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Tuple(Vec::new()));
static UNTYPED_TUPLE_TYPE: LazyLock<Arc<CassDataType>> =
LazyLock::new(|| CassDataType::new_arced(CassDataTypeInner::Tuple(Vec::new())));

#[derive(Clone)]
pub struct CassTuple {
Expand Down Expand Up @@ -92,7 +94,7 @@ unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) {
unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType {
match &BoxFFI::as_ref(tuple).data_type {
Some(t) => ArcFFI::as_ptr(t),
None => &UNTYPED_TUPLE_TYPE,
None => ArcFFI::as_ptr(&UNTYPED_TUPLE_TYPE),
}
}

Expand All @@ -116,3 +118,32 @@ make_binders!(decimal, cass_tuple_set_decimal);
make_binders!(collection, cass_tuple_set_collection);
make_binders!(tuple, cass_tuple_set_tuple);
make_binders!(user_type, cass_tuple_set_user_type);

#[cfg(test)]
mod tests {
use crate::cass_types::{
cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, CassValueType,
};

use super::{cass_tuple_data_type, cass_tuple_new};

#[test]
fn regression_empty_tuple_data_type_test() {
// This is a regression test that checks whether tuples return
// an Arc-based pointer for their type, even if they are empty.
// Previously, they would return the pointer to static data, but not Arc allocated.
unsafe {
let empty_tuple = cass_tuple_new(2);

// This would previously return a non Arc-based pointer.
let empty_tuple_dt = cass_tuple_data_type(empty_tuple);

let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET);
// This will try to increment the reference count of `empty_tuple_dt`.
// Previously, this would fail, because `empty_tuple_dt` did not originate from an Arc allocation.
cass_data_type_add_sub_type(empty_set_dt, empty_tuple_dt);

cass_data_type_free(empty_set_dt)
}
}
}