From 50f432294616e644bf0fd673c445797a93589ac0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Fri, 6 Jan 2023 11:37:46 +0100 Subject: [PATCH 1/5] Put content of CassDataType into UnsafeCell Previously `CassDataType` was just an enum, held inside `Arc`. User was given a pointer to `CassDataType` using `Arc::as_ptr` or `Arc::into_ptr`. There are however some functions that mutate the data - and they were given the very same pointers. Current code was most likely sound - but I'm not completely sure, Rust reference is very confusing in this aspect. It was however very confusing - when a programmer reads or writes a function that that *mut CassDataType it is not obivious that this data lies inside Arc and so has shared ownership. To make this more explicit this commit puts `CassDataType` inside UnsafeCell. Now each access needs to use `.get_unchecked()` and `.get_mut_unchecked()` methods and an unsafe block / function, so it will be easier to spot aliasing ^ mutability problems in the future. In the future we can use `Arc::get_mut_unchecked()` for this purpose, but it's not yet stabilised. --- scylla-rust-wrapper/src/cass_types.rs | 354 ++++++++++++++---------- scylla-rust-wrapper/src/collection.rs | 68 +++-- scylla-rust-wrapper/src/query_result.rs | 54 ++-- scylla-rust-wrapper/src/session.rs | 4 +- scylla-rust-wrapper/src/tuple.rs | 11 +- scylla-rust-wrapper/src/user_type.rs | 30 +- scylla-rust-wrapper/src/value.rs | 226 ++++++++------- 7 files changed, 433 insertions(+), 314 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index cae85d42..38197455 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -5,6 +5,7 @@ use scylla::batch::BatchType; use scylla::frame::response::result::ColumnType; use scylla::frame::types::{Consistency, SerialConsistency}; use scylla::transport::topology::{CollectionType, CqlType, NativeType, UserDefinedType}; +use std::cell::UnsafeCell; use std::collections::HashMap; use std::convert::TryFrom; use std::os::raw::c_char; @@ -112,7 +113,12 @@ impl UDTDataType { return false; } // Compare field types. - if !field.1.typecheck_equals(&other_field.1) { + if unsafe { + !field + .1 + .get_unchecked() + .typecheck_equals(other_field.1.get_unchecked()) + } { return false; } } @@ -145,7 +151,7 @@ pub struct CassColumnSpec { } #[derive(Clone, Debug, PartialEq, Eq)] -pub enum CassDataType { +pub enum CassDataTypeInner { Value(CassValueType), UDT(UDTDataType), List { @@ -167,20 +173,21 @@ pub enum CassDataType { Custom(String), } -impl CassDataType { +impl CassDataTypeInner { /// Checks for equality during typechecks. /// /// This takes into account the fact that tuples/collections may be untyped. - pub fn typecheck_equals(&self, other: &CassDataType) -> bool { + pub fn typecheck_equals(&self, other: &CassDataTypeInner) -> bool { match self { - CassDataType::Value(t) => *t == other.get_value_type(), - CassDataType::UDT(udt) => match other { - CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt), + CassDataTypeInner::Value(t) => *t == other.get_value_type(), + CassDataTypeInner::UDT(udt) => match other { + CassDataTypeInner::UDT(other_udt) => udt.typecheck_equals(other_udt), _ => false, }, - CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match other { - CassDataType::List { typ: other_typ, .. } - | CassDataType::Set { typ: other_typ, .. } => { + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match other + { + CassDataTypeInner::List { typ: other_typ, .. } + | CassDataTypeInner::Set { typ: other_typ, .. } => { // If one of them is list, and the other is set, fail the typecheck. if self.get_value_type() != other.get_value_type() { return false; @@ -188,13 +195,16 @@ impl CassDataType { match (typ, other_typ) { // One of them is untyped, skip the typecheck for subtype. (None, _) | (_, None) => true, - (Some(typ), Some(other_typ)) => typ.typecheck_equals(other_typ), + (Some(typ), Some(other_typ)) => unsafe { + typ.get_unchecked() + .typecheck_equals(other_typ.get_unchecked()) + }, } } _ => false, }, - CassDataType::Map { typ: t, .. } => match other { - CassDataType::Map { typ: t_other, .. } => match (t, t_other) { + CassDataTypeInner::Map { typ: t, .. } => match other { + CassDataTypeInner::Map { typ: t_other, .. } => match (t, t_other) { // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 // In cpp-driver the types are held in a vector. // The logic is following: @@ -204,17 +214,22 @@ impl CassDataType { (_, MapDataType::Untyped) => true, // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. - (MapDataType::Key(k), MapDataType::Key(k_other)) => k.typecheck_equals(k_other), + (MapDataType::Key(k), MapDataType::Key(k_other)) => unsafe { + k.get_unchecked().typecheck_equals(k_other.get_unchecked()) + }, ( MapDataType::KeyAndValue(k, v), MapDataType::KeyAndValue(k_other, v_other), - ) => k.typecheck_equals(k_other) && v.typecheck_equals(v_other), + ) => unsafe { + k.get_unchecked().typecheck_equals(k_other.get_unchecked()) + && v.get_unchecked().typecheck_equals(v_other.get_unchecked()) + }, _ => false, }, _ => false, }, - CassDataType::Tuple(sub) => match other { - CassDataType::Tuple(other_sub) => { + CassDataTypeInner::Tuple(sub) => match other { + CassDataTypeInner::Tuple(other_sub) => { // If either of tuples is untyped, skip the typecheck for subtypes. if sub.is_empty() || other_sub.is_empty() { return true; @@ -226,17 +241,53 @@ impl CassDataType { } sub.iter() .zip(other_sub.iter()) - .all(|(typ, other_typ)| typ.typecheck_equals(other_typ)) + .all(|(typ, other_typ)| unsafe { + typ.get_unchecked() + .typecheck_equals(other_typ.get_unchecked()) + }) } _ => false, }, - CassDataType::Custom(_) => { + CassDataTypeInner::Custom(_) => { unimplemented!("Cpp-rust-driver does not support custom types!") } } } } +#[derive(Debug)] +#[repr(transparent)] +pub struct CassDataType(UnsafeCell); + +/// PartialEq and Eq for test purposes. +impl PartialEq for CassDataType { + fn eq(&self, other: &Self) -> bool { + unsafe { self.get_unchecked() == other.get_unchecked() } + } +} +impl Eq for CassDataType {} + +unsafe impl Sync for CassDataType {} + +impl CassDataType { + pub unsafe fn get_unchecked(&self) -> &CassDataTypeInner { + &*self.0.get() + } + + #[allow(clippy::mut_from_ref)] + pub unsafe fn get_mut_unchecked(&self) -> &mut CassDataTypeInner { + &mut *self.0.get() + } + + pub const fn new(inner: CassDataTypeInner) -> CassDataType { + CassDataType(UnsafeCell::new(inner)) + } + + pub fn new_arced(inner: CassDataTypeInner) -> Arc { + Arc::new(CassDataType(UnsafeCell::new(inner))) + } +} + impl From for CassValueType { fn from(native_type: NativeType) -> CassValueType { match native_type { @@ -269,10 +320,10 @@ pub fn get_column_type_from_cql_type( user_defined_types: &HashMap>, keyspace_name: &str, ) -> CassDataType { - match cql_type { - CqlType::Native(native) => CassDataType::Value(native.clone().into()), + let inner = match cql_type { + CqlType::Native(native) => CassDataTypeInner::Value(native.clone().into()), CqlType::Collection { type_, frozen } => match type_ { - CollectionType::List(list) => CassDataType::List { + CollectionType::List(list) => CassDataTypeInner::List { typ: Some(Arc::new(get_column_type_from_cql_type( list, user_defined_types, @@ -280,7 +331,7 @@ pub fn get_column_type_from_cql_type( ))), frozen: *frozen, }, - CollectionType::Map(key, value) => CassDataType::Map { + CollectionType::Map(key, value) => CassDataTypeInner::Map { typ: MapDataType::KeyAndValue( Arc::new(get_column_type_from_cql_type( key, @@ -295,7 +346,7 @@ pub fn get_column_type_from_cql_type( ), frozen: *frozen, }, - CollectionType::Set(set) => CassDataType::Set { + CollectionType::Set(set) => CassDataTypeInner::Set { typ: Some(Arc::new(get_column_type_from_cql_type( set, user_defined_types, @@ -304,7 +355,7 @@ pub fn get_column_type_from_cql_type( frozen: *frozen, }, }, - CqlType::Tuple(tuple) => CassDataType::Tuple( + CqlType::Tuple(tuple) => CassDataTypeInner::Tuple( tuple .iter() .map(|field_type| { @@ -321,38 +372,40 @@ pub fn get_column_type_from_cql_type( Ok(resolved) => &resolved.name, Err(not_resolved) => ¬_resolved.name, }; - CassDataType::UDT(UDTDataType::create_with_params( + CassDataTypeInner::UDT(UDTDataType::create_with_params( user_defined_types, keyspace_name, name, *frozen, )) } - } + }; + + CassDataType::new(inner) } -impl CassDataType { +impl CassDataTypeInner { fn get_sub_data_type(&self, index: usize) -> Option<&Arc> { match self { - CassDataType::UDT(udt_data_type) => { + CassDataTypeInner::UDT(udt_data_type) => { udt_data_type.field_types.get(index).map(|(_, b)| b) } - CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => { + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { if index > 0 { None } else { typ.as_ref() } } - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::Untyped, .. } => None, - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::Key(k), .. } => (index == 0).then_some(k), - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(k, v), .. } => match index { @@ -360,45 +413,45 @@ impl CassDataType { 1 => Some(v), _ => None, }, - CassDataType::Tuple(v) => v.get(index), + CassDataTypeInner::Tuple(v) => v.get(index), _ => None, } } fn add_sub_data_type(&mut self, sub_type: Arc) -> Result<(), CassError> { match self { - CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match typ { + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match typ { Some(_) => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), None => { *typ = Some(sub_type); Ok(()) } }, - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(_, _), .. } => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::Key(k), frozen, } => { - *self = CassDataType::Map { + *self = CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(k.clone(), sub_type), frozen: *frozen, }; Ok(()) } - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::Untyped, frozen, } => { - *self = CassDataType::Map { + *self = CassDataTypeInner::Map { typ: MapDataType::Key(sub_type), frozen: *frozen, }; Ok(()) } - CassDataType::Tuple(types) => { + CassDataTypeInner::Tuple(types) => { types.push(sub_type); Ok(()) } @@ -408,53 +461,53 @@ impl CassDataType { pub fn get_udt_type(&self) -> &UDTDataType { match self { - CassDataType::UDT(udt) => udt, + CassDataTypeInner::UDT(udt) => udt, _ => panic!("Can get UDT out of non-UDT data type"), } } pub fn get_value_type(&self) -> CassValueType { match &self { - CassDataType::Value(value_data_type) => *value_data_type, - CassDataType::UDT { .. } => CassValueType::CASS_VALUE_TYPE_UDT, - CassDataType::List { .. } => CassValueType::CASS_VALUE_TYPE_LIST, - CassDataType::Set { .. } => CassValueType::CASS_VALUE_TYPE_SET, - CassDataType::Map { .. } => CassValueType::CASS_VALUE_TYPE_MAP, - CassDataType::Tuple(..) => CassValueType::CASS_VALUE_TYPE_TUPLE, - CassDataType::Custom(..) => CassValueType::CASS_VALUE_TYPE_CUSTOM, + CassDataTypeInner::Value(value_data_type) => *value_data_type, + CassDataTypeInner::UDT { .. } => CassValueType::CASS_VALUE_TYPE_UDT, + CassDataTypeInner::List { .. } => CassValueType::CASS_VALUE_TYPE_LIST, + CassDataTypeInner::Set { .. } => CassValueType::CASS_VALUE_TYPE_SET, + CassDataTypeInner::Map { .. } => CassValueType::CASS_VALUE_TYPE_MAP, + CassDataTypeInner::Tuple(..) => CassValueType::CASS_VALUE_TYPE_TUPLE, + CassDataTypeInner::Custom(..) => CassValueType::CASS_VALUE_TYPE_CUSTOM, } } } pub fn get_column_type(column_type: &ColumnType) -> CassDataType { - match column_type { - ColumnType::Custom(s) => CassDataType::Custom(s.clone().into_owned()), - ColumnType::Ascii => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_ASCII), - ColumnType::Boolean => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN), - ColumnType::Blob => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BLOB), - ColumnType::Counter => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_COUNTER), - ColumnType::Decimal => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DECIMAL), - ColumnType::Date => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DATE), - ColumnType::Double => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DOUBLE), - ColumnType::Float => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_FLOAT), - ColumnType::Int => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INT), - ColumnType::BigInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BIGINT), - ColumnType::Text => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TEXT), - ColumnType::Timestamp => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), - ColumnType::Inet => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INET), - ColumnType::Duration => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DURATION), - ColumnType::List(boxed_type) => CassDataType::List { + let inner = match column_type { + ColumnType::Custom(s) => CassDataTypeInner::Custom(s.clone().into_owned()), + ColumnType::Ascii => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_ASCII), + ColumnType::Boolean => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN), + ColumnType::Blob => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_BLOB), + ColumnType::Counter => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_COUNTER), + ColumnType::Decimal => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_DECIMAL), + ColumnType::Date => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_DATE), + ColumnType::Double => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_DOUBLE), + ColumnType::Float => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_FLOAT), + ColumnType::Int => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_INT), + ColumnType::BigInt => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_BIGINT), + ColumnType::Text => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_TEXT), + ColumnType::Timestamp => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), + ColumnType::Inet => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_INET), + ColumnType::Duration => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_DURATION), + ColumnType::List(boxed_type) => CassDataTypeInner::List { typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), frozen: false, }, - ColumnType::Map(key, value) => CassDataType::Map { + ColumnType::Map(key, value) => CassDataTypeInner::Map { typ: MapDataType::KeyAndValue( Arc::new(get_column_type(key.as_ref())), Arc::new(get_column_type(value.as_ref())), ), frozen: false, }, - ColumnType::Set(boxed_type) => CassDataType::Set { + ColumnType::Set(boxed_type) => CassDataTypeInner::Set { typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), frozen: false, }, @@ -462,7 +515,7 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { type_name, keyspace, field_types, - } => CassDataType::UDT(UDTDataType { + } => CassDataTypeInner::UDT(UDTDataType { field_types: field_types .iter() .map(|(name, col_type)| { @@ -476,18 +529,20 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { name: type_name.clone().into_owned(), frozen: false, }), - ColumnType::SmallInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_SMALL_INT), - ColumnType::TinyInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TINY_INT), - ColumnType::Time => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TIME), - ColumnType::Timeuuid => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TIMEUUID), - ColumnType::Tuple(v) => CassDataType::Tuple( + ColumnType::SmallInt => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_SMALL_INT), + ColumnType::TinyInt => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_TINY_INT), + ColumnType::Time => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_TIME), + ColumnType::Timeuuid => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_TIMEUUID), + ColumnType::Tuple(v) => CassDataTypeInner::Tuple( v.iter() .map(|col_type| Arc::new(get_column_type(col_type))) .collect(), ), - ColumnType::Uuid => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_UUID), - ColumnType::Varint => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_VARINT), - } + ColumnType::Uuid => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_UUID), + ColumnType::Varint => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_VARINT), + }; + + CassDataType::new(inner) } // Changed return type to const ptr - Arc::into_raw is const. @@ -496,27 +551,28 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { // This comment also applies to other functions that create CassDataType. #[no_mangle] pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const CassDataType { - let data_type = match value_type { - CassValueType::CASS_VALUE_TYPE_LIST => CassDataType::List { + let inner = match value_type { + CassValueType::CASS_VALUE_TYPE_LIST => CassDataTypeInner::List { typ: None, frozen: false, }, - CassValueType::CASS_VALUE_TYPE_SET => CassDataType::Set { + CassValueType::CASS_VALUE_TYPE_SET => CassDataTypeInner::Set { typ: None, frozen: false, }, - CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()), - CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map { + CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataTypeInner::Tuple(Vec::new()), + CassValueType::CASS_VALUE_TYPE_MAP => CassDataTypeInner::Map { typ: MapDataType::Untyped, frozen: false, }, - CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()), - CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataType::Custom("".to_string()), + CassValueType::CASS_VALUE_TYPE_UDT => CassDataTypeInner::UDT(UDTDataType::new()), + CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataTypeInner::Custom("".to_string()), CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ptr::null_mut(), - t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataType::Value(t), + t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataTypeInner::Value(t), _ => return ptr::null_mut(), }; - Arc::into_raw(Arc::new(data_type)) + + Arc::into_raw(CassDataType::new_arced(inner)) } #[no_mangle] @@ -524,21 +580,21 @@ pub unsafe extern "C" fn cass_data_type_new_from_existing( data_type: *const CassDataType, ) -> *const CassDataType { let data_type = ptr_to_ref(data_type); - Arc::into_raw(Arc::new(data_type.clone())) + Arc::into_raw(CassDataType::new_arced(data_type.get_unchecked().clone())) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_tuple(item_count: size_t) -> *const CassDataType { - Arc::into_raw(Arc::new(CassDataType::Tuple(Vec::with_capacity( - item_count as usize, - )))) + Arc::into_raw(CassDataType::new_arced(CassDataTypeInner::Tuple( + Vec::with_capacity(item_count as usize), + ))) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_udt(field_count: size_t) -> *const CassDataType { - Arc::into_raw(Arc::new(CassDataType::UDT(UDTDataType::with_capacity( - field_count as usize, - )))) + Arc::into_raw(CassDataType::new_arced(CassDataTypeInner::UDT( + UDTDataType::with_capacity(field_count as usize), + ))) } #[no_mangle] @@ -549,17 +605,17 @@ pub unsafe extern "C" fn cass_data_type_free(data_type: *mut CassDataType) { #[no_mangle] pub unsafe extern "C" fn cass_data_type_type(data_type: *const CassDataType) -> CassValueType { let data_type = ptr_to_ref(data_type); - data_type.get_value_type() + data_type.get_unchecked().get_value_type() } #[no_mangle] pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType) -> cass_bool_t { let data_type = ptr_to_ref(data_type); - let is_frozen = match data_type { - CassDataType::UDT(udt) => udt.frozen, - CassDataType::List { frozen, .. } => *frozen, - CassDataType::Set { frozen, .. } => *frozen, - CassDataType::Map { frozen, .. } => *frozen, + let is_frozen = match data_type.get_unchecked() { + CassDataTypeInner::UDT(udt) => udt.frozen, + CassDataTypeInner::List { frozen, .. } => *frozen, + CassDataTypeInner::Set { frozen, .. } => *frozen, + CassDataTypeInner::Map { frozen, .. } => *frozen, _ => false, }; @@ -573,8 +629,8 @@ pub unsafe extern "C" fn cass_data_type_type_name( type_name_length: *mut size_t, ) -> CassError { let data_type = ptr_to_ref(data_type); - match data_type { - CassDataType::UDT(UDTDataType { name, .. }) => { + match data_type.get_unchecked() { + CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, type_name, type_name_length); CassError::CASS_OK } @@ -584,7 +640,7 @@ pub unsafe extern "C" fn cass_data_type_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_type_name( - data_type: *mut CassDataType, + data_type: *const CassDataType, type_name: *const c_char, ) -> CassError { cass_data_type_set_type_name_n(data_type, type_name, strlen(type_name)) @@ -592,17 +648,17 @@ pub unsafe extern "C" fn cass_data_type_set_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_type_name_n( - data_type_raw: *mut CassDataType, + data_type_raw: *const CassDataType, type_name: *const c_char, type_name_length: size_t, ) -> CassError { - let data_type = ptr_to_ref_mut(data_type_raw); + let data_type = ptr_to_ref(data_type_raw); let type_name_string = ptr_to_cstr_n(type_name, type_name_length) .unwrap() .to_string(); - match data_type { - CassDataType::UDT(udt_data_type) => { + match data_type.get_mut_unchecked() { + CassDataTypeInner::UDT(udt_data_type) => { udt_data_type.name = type_name_string; CassError::CASS_OK } @@ -617,8 +673,8 @@ pub unsafe extern "C" fn cass_data_type_keyspace( keyspace_length: *mut size_t, ) -> CassError { let data_type = ptr_to_ref(data_type); - match data_type { - CassDataType::UDT(UDTDataType { name, .. }) => { + match data_type.get_unchecked() { + CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, keyspace, keyspace_length); CassError::CASS_OK } @@ -628,7 +684,7 @@ pub unsafe extern "C" fn cass_data_type_keyspace( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_keyspace( - data_type: *mut CassDataType, + data_type: *const CassDataType, keyspace: *const c_char, ) -> CassError { cass_data_type_set_keyspace_n(data_type, keyspace, strlen(keyspace)) @@ -636,17 +692,17 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_keyspace_n( - data_type: *mut CassDataType, + data_type: *const CassDataType, keyspace: *const c_char, keyspace_length: size_t, ) -> CassError { - let data_type = ptr_to_ref_mut(data_type); + let data_type = ptr_to_ref(data_type); let keyspace_string = ptr_to_cstr_n(keyspace, keyspace_length) .unwrap() .to_string(); - match data_type { - CassDataType::UDT(udt_data_type) => { + match data_type.get_mut_unchecked() { + CassDataTypeInner::UDT(udt_data_type) => { udt_data_type.keyspace = keyspace_string; CassError::CASS_OK } @@ -661,8 +717,8 @@ pub unsafe extern "C" fn cass_data_type_class_name( class_name_length: *mut size_t, ) -> CassError { let data_type = ptr_to_ref(data_type); - match data_type { - CassDataType::Custom(name) => { + match data_type.get_unchecked() { + CassDataTypeInner::Custom(name) => { write_str_to_c(name, class_name, class_name_length); CassError::CASS_OK } @@ -672,7 +728,7 @@ pub unsafe extern "C" fn cass_data_type_class_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_class_name( - data_type: *mut CassDataType, + data_type: *const CassDataType, class_name: *const ::std::os::raw::c_char, ) -> CassError { cass_data_type_set_class_name_n(data_type, class_name, strlen(class_name)) @@ -680,16 +736,16 @@ pub unsafe extern "C" fn cass_data_type_set_class_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_class_name_n( - data_type: *mut CassDataType, + data_type: *const CassDataType, class_name: *const ::std::os::raw::c_char, class_name_length: size_t, ) -> CassError { - let data_type = ptr_to_ref_mut(data_type); + let data_type = ptr_to_ref(data_type); let class_string = ptr_to_cstr_n(class_name, class_name_length) .unwrap() .to_string(); - match data_type { - CassDataType::Custom(name) => { + match data_type.get_mut_unchecked() { + CassDataTypeInner::Custom(name) => { *name = class_string; CassError::CASS_OK } @@ -700,17 +756,19 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDataType) -> size_t { let data_type = ptr_to_ref(data_type); - match data_type { - CassDataType::Value(..) => 0, - CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, - CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t, - CassDataType::Map { typ, .. } => match typ { + match data_type.get_unchecked() { + CassDataTypeInner::Value(..) => 0, + CassDataTypeInner::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { + typ.is_some() as size_t + } + CassDataTypeInner::Map { typ, .. } => match typ { MapDataType::Untyped => 0, MapDataType::Key(_) => 1, MapDataType::KeyAndValue(_, _) => 2, }, - CassDataType::Tuple(v) => v.len() as size_t, - CassDataType::Custom(..) => 0, + CassDataTypeInner::Tuple(v) => v.len() as size_t, + CassDataTypeInner::Custom(..) => 0, } } @@ -725,7 +783,8 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type( index: size_t, ) -> *const CassDataType { let data_type = ptr_to_ref(data_type); - let sub_type: Option<&Arc> = data_type.get_sub_data_type(index as usize); + let sub_type: Option<&Arc> = + data_type.get_unchecked().get_sub_data_type(index as usize); match sub_type { None => std::ptr::null(), @@ -750,8 +809,8 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( ) -> *const CassDataType { let data_type = ptr_to_ref(data_type); let name_str = ptr_to_cstr_n(name, name_length).unwrap(); - match data_type { - CassDataType::UDT(udt) => match udt.get_field_by_name(name_str) { + match data_type.get_unchecked() { + CassDataTypeInner::UDT(udt) => match udt.get_field_by_name(name_str) { None => std::ptr::null(), Some(t) => Arc::as_ptr(t), }, @@ -767,8 +826,8 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name( name_length: *mut size_t, ) -> CassError { let data_type = ptr_to_ref(data_type); - match data_type { - CassDataType::UDT(udt) => match udt.field_types.get(index as usize) { + match data_type.get_unchecked() { + CassDataTypeInner::UDT(udt) => match udt.field_types.get(index as usize) { None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, Some((field_name, _)) => { write_str_to_c(field_name, name, name_length); @@ -781,11 +840,14 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type( - data_type: *mut CassDataType, + data_type: *const CassDataType, sub_data_type: *const CassDataType, ) -> CassError { - let data_type = ptr_to_ref_mut(data_type); - match data_type.add_sub_data_type(clone_arced(sub_data_type)) { + let data_type = ptr_to_ref(data_type); + match data_type + .get_mut_unchecked() + .add_sub_data_type(clone_arced(sub_data_type)) + { Ok(()) => CassError::CASS_OK, Err(e) => e, } @@ -793,7 +855,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name( - data_type: *mut CassDataType, + data_type: *const CassDataType, name: *const c_char, sub_data_type: *const CassDataType, ) -> CassError { @@ -802,7 +864,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( - data_type_raw: *mut CassDataType, + data_type_raw: *const CassDataType, name: *const c_char, name_length: size_t, sub_data_type_raw: *const CassDataType, @@ -810,9 +872,9 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( let name_string = ptr_to_cstr_n(name, name_length).unwrap().to_string(); let sub_data_type = clone_arced(sub_data_type_raw); - let data_type = ptr_to_ref_mut(data_type_raw); - match data_type { - CassDataType::UDT(udt_data_type) => { + let data_type = ptr_to_ref(data_type_raw); + match data_type.get_mut_unchecked() { + CassDataTypeInner::UDT(udt_data_type) => { // The Cpp Driver does not check whether field_types size // exceeded field_count. udt_data_type.field_types.push((name_string, sub_data_type)); @@ -824,31 +886,31 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type( - data_type: *mut CassDataType, + data_type: *const CassDataType, sub_value_type: CassValueType, ) -> CassError { - let sub_data_type = Arc::new(CassDataType::Value(sub_value_type)); + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); cass_data_type_add_sub_type(data_type, Arc::as_ptr(&sub_data_type)) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( - data_type: *mut CassDataType, + data_type: *const CassDataType, name: *const c_char, sub_value_type: CassValueType, ) -> CassError { - let sub_data_type = Arc::new(CassDataType::Value(sub_value_type)); + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); cass_data_type_add_sub_type_by_name(data_type, name, Arc::as_ptr(&sub_data_type)) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name_n( - data_type: *mut CassDataType, + data_type: *const CassDataType, name: *const c_char, name_length: size_t, sub_value_type: CassValueType, ) -> CassError { - let sub_data_type = Arc::new(CassDataType::Value(sub_value_type)); + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); cass_data_type_add_sub_type_by_name_n(data_type, name, name_length, Arc::as_ptr(&sub_data_type)) } diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index a396e7ae..074bf5f0 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -1,6 +1,6 @@ use crate::cass_collection_types::CassCollectionType; use crate::cass_error::CassError; -use crate::cass_types::{CassDataType, MapDataType}; +use crate::cass_types::{CassDataType, CassDataTypeInner, MapDataType}; use crate::types::*; use crate::value::CassCqlValue; use crate::{argconv::*, value}; @@ -8,18 +8,18 @@ use std::convert::TryFrom; use std::sync::Arc; // These constants help us to save an allocation in case user calls `cass_collection_new` (untyped collection). -static UNTYPED_LIST_TYPE: CassDataType = CassDataType::List { +static UNTYPED_LIST_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::List { typ: None, frozen: false, -}; -static UNTYPED_SET_TYPE: CassDataType = CassDataType::Set { +}); +static UNTYPED_SET_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Set { typ: None, frozen: false, -}; -static UNTYPED_MAP_TYPE: CassDataType = CassDataType::Map { +}); +static UNTYPED_MAP_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Map { typ: MapDataType::Untyped, frozen: false, -}; +}); #[derive(Clone)] pub struct CassCollection { @@ -35,10 +35,14 @@ impl CassCollection { let index = self.items.len(); // Do validation only if it's a typed collection. - if let Some(data_type) = &self.data_type { - match data_type.as_ref() { - CassDataType::List { typ: subtype, .. } - | CassDataType::Set { typ: subtype, .. } => { + if let Some(data_type) = &self + .data_type + .as_ref() + .map(|dt| unsafe { dt.get_unchecked() }) + { + match data_type { + CassDataTypeInner::List { typ: subtype, .. } + | CassDataTypeInner::Set { typ: subtype, .. } => { if let Some(subtype) = subtype { if !value::is_type_compatible(value, subtype) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; @@ -46,7 +50,7 @@ impl CassCollection { } } - CassDataType::Map { typ, .. } => { + CassDataTypeInner::Map { typ, .. } => { // Cpp-driver does the typecheck only if both map types are present... // However, we decided not to mimic this behaviour (which is probably a bug). // We will do the typecheck if just the key type is defined as well (half-typed maps). @@ -146,12 +150,16 @@ unsafe extern "C" fn cass_collection_new_from_data_type( item_count: size_t, ) -> *mut CassCollection { let data_type = clone_arced(data_type); - let (capacity, collection_type) = match data_type.as_ref() { - CassDataType::List { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST), - CassDataType::Set { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_SET), + let (capacity, collection_type) = match data_type.get_unchecked() { + CassDataTypeInner::List { .. } => { + (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST) + } + CassDataTypeInner::Set { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_SET), // Maps consist of a key and a value, so twice // the number of CassCqlValue will be stored. - CassDataType::Map { .. } => (item_count * 2, CassCollectionType::CASS_COLLECTION_TYPE_MAP), + CassDataTypeInner::Map { .. } => { + (item_count * 2, CassCollectionType::CASS_COLLECTION_TYPE_MAP) + } _ => return std::ptr::null_mut(), }; let capacity = capacity as usize; @@ -216,7 +224,7 @@ mod tests { use crate::{ cass_error::CassError, - cass_types::{CassDataType, CassValueType, MapDataType}, + cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType}, collection::{ cass_collection_append_double, cass_collection_append_float, cass_collection_free, }, @@ -256,7 +264,7 @@ mod tests { // untyped map (via cass_collection_new_from_data_type - collection's type is Some(untyped_map)). { - let dt = Arc::new(CassDataType::Map { + let dt = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::Untyped, frozen: false, }); @@ -285,8 +293,8 @@ mod tests { // half-typed map (key-only) { - let dt = Arc::new(CassDataType::Map { - typ: MapDataType::Key(Arc::new(CassDataType::Value( + let dt = CassDataType::new_arced(CassDataTypeInner::Map { + typ: MapDataType::Key(CassDataType::new_arced(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_BOOLEAN, ))), frozen: false, @@ -324,10 +332,12 @@ mod tests { // typed map { - let dt = Arc::new(CassDataType::Map { + let dt = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::KeyAndValue( - Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)), - Arc::new(CassDataType::Value( + CassDataType::new_arced(CassDataTypeInner::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + )), + CassDataType::new_arced(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_SMALL_INT, )), ), @@ -383,7 +393,7 @@ mod tests { // untyped set (via cass_collection_new_from_data_type, collection's type is Some(untyped_set)) { - let dt = Arc::new(CassDataType::Set { + let dt = CassDataType::new_arced(CassDataTypeInner::Set { typ: None, frozen: false, }); @@ -404,8 +414,8 @@ mod tests { // typed set { - let dt = Arc::new(CassDataType::Set { - typ: Some(Arc::new(CassDataType::Value( + let dt = CassDataType::new_arced(CassDataTypeInner::Set { + typ: Some(CassDataType::new_arced(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_BOOLEAN, ))), frozen: false, @@ -443,7 +453,7 @@ mod tests { // untyped list (via cass_collection_new_from_data_type, collection's type is Some(untyped_list)) { - let dt = Arc::new(CassDataType::Set { + let dt = CassDataType::new_arced(CassDataTypeInner::Set { typ: None, frozen: false, }); @@ -464,8 +474,8 @@ mod tests { // typed list { - let dt = Arc::new(CassDataType::Set { - typ: Some(Arc::new(CassDataType::Value( + let dt = CassDataType::new_arced(CassDataTypeInner::Set { + typ: Some(CassDataType::new_arced(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_BOOLEAN, ))), frozen: false, diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 6f3383ec..78b8a7e7 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1,7 +1,8 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::{ - cass_data_type_type, get_column_type, CassColumnSpec, CassDataType, CassValueType, MapDataType, + cass_data_type_type, get_column_type, CassColumnSpec, CassDataType, CassDataTypeInner, + CassValueType, MapDataType, }; use crate::inet::CassInet; use crate::metadata::{ @@ -195,10 +196,10 @@ fn create_cass_row_columns(row: Row, metadata: &Arc) -> Vec< } fn get_column_value(column: CqlValue, column_type: &Arc) -> Value { - match (column, column_type.as_ref()) { + match (column, unsafe { column_type.get_unchecked() }) { ( CqlValue::List(list), - CassDataType::List { + CassDataTypeInner::List { typ: Some(list_type), .. }, @@ -212,7 +213,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value )), ( CqlValue::Map(map), - CassDataType::Map { + CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(key_type, value_type), .. }, @@ -234,7 +235,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value )), ( CqlValue::Set(set), - CassDataType::Set { + CassDataTypeInner::Set { typ: Some(set_type), .. }, @@ -252,7 +253,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value type_name, fields, }, - CassDataType::UDT(udt_type), + CassDataTypeInner::UDT(udt_type), ) => CollectionValue(Collection::UserDefinedType { keyspace, type_name, @@ -274,7 +275,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value }) .collect(), }), - (CqlValue::Tuple(tuple), CassDataType::Tuple(tuple_types)) => { + (CqlValue::Tuple(tuple), CassDataTypeInner::Tuple(tuple_types)) => { CollectionValue(Collection::Tuple( tuple .into_iter() @@ -1472,7 +1473,7 @@ pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> ca let val = ptr_to_ref(value); matches!( - val.value_type.get_value_type(), + val.value_type.get_unchecked().get_value_type(), CassValueType::CASS_VALUE_TYPE_LIST | CassValueType::CASS_VALUE_TYPE_SET | CassValueType::CASS_VALUE_TYPE_MAP @@ -1483,7 +1484,8 @@ pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> ca pub unsafe extern "C" fn cass_value_is_duration(value: *const CassValue) -> cass_bool_t { let val = ptr_to_ref(value); - (val.value_type.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) as cass_bool_t + (val.value_type.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) + as cass_bool_t } #[no_mangle] @@ -1508,15 +1510,15 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( ) -> CassValueType { let val = ptr_to_ref(collection); - match val.value_type.as_ref() { - CassDataType::List { + match val.value_type.get_unchecked() { + CassDataTypeInner::List { typ: Some(list), .. - } => list.get_value_type(), - CassDataType::Set { typ: Some(set), .. } => set.get_value_type(), - CassDataType::Map { + } => list.get_unchecked().get_value_type(), + CassDataTypeInner::Set { typ: Some(set), .. } => set.get_unchecked().get_value_type(), + CassDataTypeInner::Map { typ: MapDataType::Key(key) | MapDataType::KeyAndValue(key, _), .. - } => key.get_value_type(), + } => key.get_unchecked().get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, } } @@ -1527,11 +1529,11 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( ) -> CassValueType { let val = ptr_to_ref(collection); - match val.value_type.as_ref() { - CassDataType::Map { + match val.value_type.get_unchecked() { + CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(_, value), .. - } => value.get_value_type(), + } => value.get_unchecked().get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, } } @@ -1614,7 +1616,7 @@ mod tests { use crate::{ cass_error::CassError, - cass_types::{CassDataType, CassValueType}, + cass_types::{CassDataType, CassDataTypeInner, CassValueType}, query_result::{ cass_result_column_data_type, cass_result_column_name, cass_result_first_row, ptr_to_cstr_n, ptr_to_ref, size_t, @@ -1723,22 +1725,26 @@ mod tests { { let first_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 0)); assert_eq!( - &CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BIGINT), + &CassDataType::new(CassDataTypeInner::Value( + CassValueType::CASS_VALUE_TYPE_BIGINT + )), first_col_data_type ); let second_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 1)); assert_eq!( - &CassDataType::Value(CassValueType::CASS_VALUE_TYPE_VARINT), + &CassDataType::new(CassDataTypeInner::Value( + CassValueType::CASS_VALUE_TYPE_VARINT + )), second_col_data_type ); let third_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 2)); assert_eq!( - &CassDataType::List { - typ: Some(Arc::new(CassDataType::Value( + &CassDataType::new(CassDataTypeInner::List { + typ: Some(CassDataType::new_arced(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_DOUBLE ))), frozen: false - }, + }), third_col_data_type ); let out_of_bound_col_data_type = cass_result_column_data_type(result_ptr, 555); diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 5ec3cee4..8010adf5 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; -use crate::cass_types::{CassDataType, UDTDataType}; +use crate::cass_types::{CassDataType, CassDataTypeInner, UDTDataType}; use crate::cluster::build_session_builder; use crate::cluster::CassCluster; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; @@ -509,7 +509,7 @@ pub unsafe extern "C" fn cass_session_get_schema_meta( for udt_name in keyspace.user_defined_types.keys() { user_defined_type_data_type.insert( udt_name.clone(), - Arc::new(CassDataType::UDT(UDTDataType::create_with_params( + CassDataType::new_arced(CassDataTypeInner::UDT(UDTDataType::create_with_params( &keyspace.user_defined_types, keyspace_name, udt_name, diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index df1f9889..fd8d5fa4 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -1,12 +1,13 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::CassDataType; +use crate::cass_types::CassDataTypeInner; use crate::types::*; use crate::value; use crate::value::CassCqlValue; use std::sync::Arc; -static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::Tuple(Vec::new()); +static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::new(CassDataTypeInner::Tuple(Vec::new())); #[derive(Clone)] pub struct CassTuple { @@ -17,8 +18,8 @@ pub struct CassTuple { impl CassTuple { fn get_types(&self) -> Option<&Vec>> { match &self.data_type { - Some(t) => match &**t { - CassDataType::Tuple(v) => Some(v), + Some(t) => match unsafe { t.as_ref().get_unchecked() } { + CassDataTypeInner::Tuple(v) => Some(v), _ => unreachable!(), }, None => None, @@ -70,8 +71,8 @@ unsafe extern "C" fn cass_tuple_new_from_data_type( data_type: *const CassDataType, ) -> *mut CassTuple { let data_type = clone_arced(data_type); - let item_count = match &*data_type { - CassDataType::Tuple(v) => v.len(), + let item_count = match data_type.get_unchecked() { + CassDataTypeInner::Tuple(v) => v.len(), _ => return std::ptr::null_mut(), }; Box::into_raw(Box::new(CassTuple { diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index e723c7b7..1775ace7 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -1,5 +1,5 @@ use crate::cass_error::CassError; -use crate::cass_types::CassDataType; +use crate::cass_types::{CassDataType, CassDataTypeInner}; use crate::types::*; use crate::value::CassCqlValue; use crate::{argconv::*, value}; @@ -19,7 +19,9 @@ impl CassUserType { if index >= self.field_values.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - if !value::is_type_compatible(&value, &self.data_type.get_udt_type().field_types[index].1) { + if !value::is_type_compatible(&value, unsafe { + &self.data_type.get_unchecked().get_udt_type().field_types[index].1 + }) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } self.field_values[index] = value; @@ -28,9 +30,14 @@ impl CassUserType { fn set_field_by_name(&mut self, name: &str, value: Option) -> CassError { let mut found_field: bool = false; - for (index, (field_name, field_type)) in - self.data_type.get_udt_type().field_types.iter().enumerate() - { + for (index, (field_name, field_type)) in unsafe { + self.data_type + .get_unchecked() + .get_udt_type() + .field_types + .iter() + .enumerate() + } { if *field_name == name { found_field = true; if index >= self.field_values.len() { @@ -58,7 +65,14 @@ impl From<&CassUserType> for CassCqlValue { fields: user_type .field_values .iter() - .zip(user_type.data_type.get_udt_type().field_types.iter()) + .zip(unsafe { + user_type + .data_type + .get_unchecked() + .get_udt_type() + .field_types + .iter() + }) .map(|(v, (name, _))| (name.clone(), v.clone())) .collect(), } @@ -71,8 +85,8 @@ pub unsafe extern "C" fn cass_user_type_new_from_data_type( ) -> *mut CassUserType { let data_type = clone_arced(data_type_raw); - match &*data_type { - CassDataType::UDT(udt_data_type) => { + match data_type.get_unchecked() { + CassDataTypeInner::UDT(udt_data_type) => { let field_values = vec![None; udt_data_type.field_types.len()]; Box::into_raw(Box::new(CassUserType { data_type, diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index e0884f0e..006b4ef1 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -83,87 +83,103 @@ pub fn is_type_compatible(value: &Option, typ: &CassDataType) -> b impl CassCqlValue { pub fn is_type_compatible(&self, typ: &CassDataType) -> bool { match self { - CassCqlValue::TinyInt(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TINY_INT - } - CassCqlValue::SmallInt(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT - } - CassCqlValue::Int(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INT, - CassCqlValue::BigInt(_) => { + CassCqlValue::TinyInt(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_TINY_INT + }, + CassCqlValue::SmallInt(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT + }, + CassCqlValue::Int(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_INT + }, + CassCqlValue::BigInt(_) => unsafe { matches!( - typ.get_value_type(), + typ.get_unchecked().get_value_type(), CassValueType::CASS_VALUE_TYPE_BIGINT | CassValueType::CASS_VALUE_TYPE_COUNTER | CassValueType::CASS_VALUE_TYPE_TIMESTAMP | CassValueType::CASS_VALUE_TYPE_TIME ) - } - CassCqlValue::Float(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_FLOAT, - CassCqlValue::Double(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DOUBLE - } - CassCqlValue::Boolean(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_BOOLEAN - } - CassCqlValue::Text(_) => { + }, + CassCqlValue::Float(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_FLOAT + }, + CassCqlValue::Double(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DOUBLE + }, + CassCqlValue::Boolean(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_BOOLEAN + }, + CassCqlValue::Text(_) => unsafe { matches!( - typ.get_value_type(), + typ.get_unchecked().get_value_type(), CassValueType::CASS_VALUE_TYPE_TEXT | CassValueType::CASS_VALUE_TYPE_VARCHAR | CassValueType::CASS_VALUE_TYPE_ASCII | CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT ) - } - CassCqlValue::Blob(_) => matches!( - typ.get_value_type(), - CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT - ), - CassCqlValue::Uuid(_) => matches!( - typ.get_value_type(), - CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID - ), - CassCqlValue::Date(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE, - CassCqlValue::Inet(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INET, - CassCqlValue::Duration(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION - } - CassCqlValue::Decimal(_) => { - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DECIMAL - } - CassCqlValue::Tuple { data_type, .. } => { + }, + CassCqlValue::Blob(_) => unsafe { + matches!( + typ.get_unchecked().get_value_type(), + CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT + ) + }, + CassCqlValue::Uuid(_) => unsafe { + matches!( + typ.get_unchecked().get_value_type(), + CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID + ) + }, + CassCqlValue::Date(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE + }, + CassCqlValue::Inet(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_INET + }, + CassCqlValue::Duration(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION + }, + CassCqlValue::Decimal(_) => unsafe { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DECIMAL + }, + CassCqlValue::Tuple { data_type, .. } => unsafe { if let Some(dt) = data_type { - return dt.typecheck_equals(typ); + return dt.get_unchecked().typecheck_equals(typ.get_unchecked()); } // Untyped tuple. - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TUPLE - } - CassCqlValue::List { data_type, .. } => { + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_TUPLE + }, + CassCqlValue::List { data_type, .. } => unsafe { if let Some(dt) = data_type { - dt.typecheck_equals(typ) + dt.get_unchecked().typecheck_equals(typ.get_unchecked()) } else { // Untyped list. - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_LIST + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_LIST } - } - CassCqlValue::Map { data_type, .. } => { + }, + CassCqlValue::Map { data_type, .. } => unsafe { if let Some(dt) = data_type { - dt.typecheck_equals(typ) + dt.get_unchecked().typecheck_equals(typ.get_unchecked()) } else { // Untyped map. - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_MAP + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_MAP } - } - CassCqlValue::Set { data_type, .. } => { + }, + CassCqlValue::Set { data_type, .. } => unsafe { if let Some(dt) = data_type { - dt.typecheck_equals(typ) + dt.get_unchecked().typecheck_equals(typ.get_unchecked()) } else { // Untyped set. - typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SET + typ.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_SET } - } - CassCqlValue::UserDefinedType { data_type, .. } => data_type.typecheck_equals(typ), + }, + CassCqlValue::UserDefinedType { data_type, .. } => unsafe { + data_type + .get_unchecked() + .typecheck_equals(typ.get_unchecked()) + }, } } } @@ -402,14 +418,14 @@ mod tests { use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ - cass_types::{CassDataType, CassValueType, MapDataType, UDTDataType}, + cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType, UDTDataType}, value::{is_type_compatible, CassCqlValue}, }; - fn all_value_data_types() -> [CassDataType; 26] { - let from = |v_typ: CassValueType| CassDataType::Value(v_typ); + fn all_value_data_types() -> Vec { + let from = |v_typ: CassValueType| CassDataType::new(CassDataTypeInner::Value(v_typ)); - [ + vec![ from(CassValueType::CASS_VALUE_TYPE_TINY_INT), from(CassValueType::CASS_VALUE_TYPE_SMALL_INT), from(CassValueType::CASS_VALUE_TYPE_INT), @@ -441,7 +457,7 @@ mod tests { #[test] fn typecheck_simple_test() { - let from = |v_typ: CassValueType| CassDataType::Value(v_typ); + let from = |v_typ: CassValueType| CassDataType::new(CassDataTypeInner::Value(v_typ)); struct TestCase { value: Option, compatible_types: Vec, @@ -451,7 +467,7 @@ mod tests { // Null -> all types TestCase { value: None, - compatible_types: all_value_data_types().to_vec(), + compatible_types: all_value_data_types(), }, // i8 -> tinyint TestCase { @@ -594,10 +610,15 @@ mod tests { // Let's make some types accessible for all test cases. // To make sure that e.g. Tuple against UDT typecheck fails. - let data_type_float = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_FLOAT)); - let data_type_int = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INT)); - let data_type_bool = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)); - let data_type_tuple = Arc::new(CassDataType::Tuple(vec![ + let data_type_float = CassDataType::new_arced(CassDataTypeInner::Value( + CassValueType::CASS_VALUE_TYPE_FLOAT, + )); + let data_type_int = + CassDataType::new_arced(CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_INT)); + let data_type_bool = CassDataType::new_arced(CassDataTypeInner::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + )); + let data_type_tuple = CassDataType::new_arced(CassDataTypeInner::Tuple(vec![ data_type_float.clone(), data_type_int.clone(), data_type_bool.clone(), @@ -612,42 +633,44 @@ mod tests { let user_udt_name = "user".to_owned(); let empty_str = "".to_owned(); - let data_type_udt_simple = Arc::new(CassDataType::UDT(UDTDataType { + let data_type_udt_simple = CassDataType::new_arced(CassDataTypeInner::UDT(UDTDataType { field_types: simple_fields.clone(), keyspace: ks_keyspace_name.clone(), name: user_udt_name.clone(), frozen: false, })); - let data_type_int_list = Arc::new(CassDataType::List { + let data_type_int_list = CassDataType::new_arced(CassDataTypeInner::List { typ: Some(data_type_int.clone()), frozen: false, }); - let data_type_int_set = Arc::new(CassDataType::Set { + let data_type_int_set = CassDataType::new_arced(CassDataTypeInner::Set { typ: Some(data_type_int.clone()), frozen: false, }); - let data_type_bool_float_map = Arc::new(CassDataType::Map { + let data_type_bool_float_map = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(data_type_bool.clone(), data_type_float.clone()), frozen: false, }); // TUPLES { - let data_type_untyped_tuple = Arc::new(CassDataType::Tuple(vec![])); - let data_type_small_tuple = Arc::new(CassDataType::Tuple(vec![data_type_bool.clone()])); - let data_type_nested_tuple = Arc::new(CassDataType::Tuple(vec![ + let data_type_untyped_tuple = CassDataType::new_arced(CassDataTypeInner::Tuple(vec![])); + let data_type_small_tuple = + CassDataType::new_arced(CassDataTypeInner::Tuple(vec![data_type_bool.clone()])); + let data_type_nested_tuple = CassDataType::new_arced(CassDataTypeInner::Tuple(vec![ data_type_small_tuple.clone(), data_type_int.clone(), data_type_tuple.clone(), ])); - let data_type_nested_untyped_tuple = Arc::new(CassDataType::Tuple(vec![ - data_type_untyped_tuple.clone(), - data_type_int.clone(), - data_type_untyped_tuple.clone(), - ])); + let data_type_nested_untyped_tuple = + CassDataType::new_arced(CassDataTypeInner::Tuple(vec![ + data_type_untyped_tuple.clone(), + data_type_int.clone(), + data_type_untyped_tuple.clone(), + ])); let test_cases = &[ // Untyped tuple -> created via `cass_tuple_new` @@ -748,30 +771,33 @@ mod tests { // UDT { - let data_type_udt_simple_empty_keyspace = Arc::new(CassDataType::UDT(UDTDataType { - field_types: simple_fields.clone(), - keyspace: empty_str.to_owned(), - name: user_udt_name.clone(), - frozen: false, - })); - let data_type_udt_simple_empty_name = Arc::new(CassDataType::UDT(UDTDataType { - field_types: simple_fields.clone(), - keyspace: ks_keyspace_name.clone(), - name: empty_str.clone(), - frozen: false, - })); + let data_type_udt_simple_empty_keyspace = + CassDataType::new_arced(CassDataTypeInner::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: empty_str.to_owned(), + name: user_udt_name.clone(), + frozen: false, + })); + let data_type_udt_simple_empty_name = + CassDataType::new_arced(CassDataTypeInner::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: empty_str.clone(), + frozen: false, + })); // A prefix of simple_fields. let small_fields = vec![ ("foo".to_owned(), data_type_float.clone()), ("bar".to_owned(), data_type_bool.clone()), ]; - let data_type_udt_small = Arc::new(CassDataType::UDT(UDTDataType { - field_types: small_fields.clone(), - keyspace: ks_keyspace_name.clone(), - name: user_udt_name.clone(), - frozen: false, - })); + let data_type_udt_small = + CassDataType::new_arced(CassDataTypeInner::UDT(UDTDataType { + field_types: small_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: user_udt_name.clone(), + frozen: false, + })); let test_cases = &[TestCase { value: CassCqlValue::UserDefinedType { @@ -800,34 +826,34 @@ mod tests { // COLLECTIONS { - let data_type_untyped_list = Arc::new(CassDataType::List { + let data_type_untyped_list = CassDataType::new_arced(CassDataTypeInner::List { typ: None, frozen: false, }); - let data_type_float_list = Arc::new(CassDataType::List { + let data_type_float_list = CassDataType::new_arced(CassDataTypeInner::List { typ: Some(data_type_float.clone()), frozen: false, }); - let data_type_untyped_set = Arc::new(CassDataType::Set { + let data_type_untyped_set = CassDataType::new_arced(CassDataTypeInner::Set { typ: None, frozen: false, }); - let data_type_float_set = Arc::new(CassDataType::Set { + let data_type_float_set = CassDataType::new_arced(CassDataTypeInner::Set { typ: Some(data_type_float.clone()), frozen: false, }); - let data_type_untyped_map = Arc::new(CassDataType::Map { + let data_type_untyped_map = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::Untyped, frozen: false, }); - let data_type_typed_key_float_map = Arc::new(CassDataType::Map { + let data_type_typed_key_float_map = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::Key(data_type_float.clone()), frozen: false, }); - let data_type_float_int_map = Arc::new(CassDataType::Map { + let data_type_float_int_map = CassDataType::new_arced(CassDataTypeInner::Map { typ: MapDataType::KeyAndValue(data_type_float.clone(), data_type_int.clone()), frozen: false, }); From 7468390f62a0e91a2fd7c36f52411eca792075d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Tue, 26 Nov 2024 12:18:05 +0100 Subject: [PATCH 2/5] types: derive/impl PartialEq and Eq only for test purposes Implementation of `PartialEq` for `CassDataType` hides a possibly unsafe operation. Let's make sure that we do not depend on it in the code - only use it for test purposes. --- scylla-rust-wrapper/src/cass_types.rs | 11 ++++++++--- scylla-rust-wrapper/src/value.rs | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 38197455..4f3f7f11 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -16,7 +16,8 @@ pub(crate) use crate::cass_batch_types::CassBatchType; pub(crate) use crate::cass_consistency_types::CassConsistency; pub(crate) use crate::cass_data_types::CassValueType; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] pub struct UDTDataType { // Vec to preserve the order of types pub field_types: Vec<(String, Arc)>, @@ -137,7 +138,8 @@ impl Default for UDTDataType { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] pub enum MapDataType { Untyped, Key(Arc), @@ -150,7 +152,8 @@ pub struct CassColumnSpec { pub data_type: Arc, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] pub enum CassDataTypeInner { Value(CassValueType), UDT(UDTDataType), @@ -260,11 +263,13 @@ impl CassDataTypeInner { pub struct CassDataType(UnsafeCell); /// PartialEq and Eq for test purposes. +#[cfg(test)] impl PartialEq for CassDataType { fn eq(&self, other: &Self) -> bool { unsafe { self.get_unchecked() == other.get_unchecked() } } } +#[cfg(test)] impl Eq for CassDataType {} unsafe impl Sync for CassDataType {} diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 006b4ef1..86687169 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -31,7 +31,8 @@ use crate::cass_types::{CassDataType, CassValueType}; /// /// There is no such method as `cass_statement_bind_counter`, and so /// we need to serialize the counter value using `CassCqlValue::BigInt`. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq))] pub enum CassCqlValue { TinyInt(i8), SmallInt(i16), From 209a9b97e5b098a1c70f357aab613905e23d08d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Fri, 23 Dec 2022 19:58:48 +0100 Subject: [PATCH 3/5] argconv: Define new traits for a bit safer FFI The reasons behind the API (and this PR) are following: - reduce the lifetime of reference obtained from the pointer Current pointer-to-ref conversion API (ptr_to_ref_[mut]) returns a reference with a static lifetime. New API is parameterized by a lifetime, which is not 'static. - Ensure that IF the memory associated with the pointer was allocated a certain way, the raw pointer will be converted to the corresponding Rust pointer primitive i.e. Box or Arc (or reference if the pointer was not obtained from an explicit allocation). However, this does not give us any guarantees about the origin of the pointer. Consider following example: // Rust impl ArcFFI for Foo {} fn extern "C" f1() -> *const Foo { // a pointer to stack variable // Also applies to some valid pointer obtained from the reference // to the field of some already heap-allocated object. // Decided to go with a stack variable to keep the example simple. let foo = Foo; &foo } fn extern "C" f2(foo: *const Foo) { let foo = ArcFFI::cloned_from_ptr(foo); } // C Foo *foo = f1(); f2(foo); // Segfault. // Even if f1() returned a valid pointer, that points to some // heap-allocated memory. The pointer was not obtained from an Arc allocation. // I.e., it was not obtained via Arc::into_raw(). To guarantee this, we need to introduce a special type for pointer that would represent the pointer's properties. This will be done in a follow-up PR. --- scylla-rust-wrapper/src/argconv.rs | 74 ++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/scylla-rust-wrapper/src/argconv.rs b/scylla-rust-wrapper/src/argconv.rs index 40a15737..db244700 100644 --- a/scylla-rust-wrapper/src/argconv.rs +++ b/scylla-rust-wrapper/src/argconv.rs @@ -97,3 +97,77 @@ macro_rules! make_c_str { #[cfg(test)] pub(crate) use make_c_str; + +/// Defines a pointer manipulation API for non-shared heap-allocated data. +/// +/// Implement this trait for types that are allocated by the driver via [`Box::new`], +/// and then returned to the user as a pointer. The user is responsible for freeing +/// the memory associated with the pointer using corresponding driver's API function. +pub trait BoxFFI { + fn into_ptr(self: Box) -> *mut Self { + Box::into_raw(self) + } + unsafe fn from_ptr(ptr: *mut Self) -> Box { + Box::from_raw(ptr) + } + unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { + ptr.as_ref() + } + unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + ptr.as_ref().unwrap() + } + unsafe fn as_mut_ref<'a>(ptr: *mut Self) -> &'a mut Self { + ptr.as_mut().unwrap() + } + unsafe fn free(ptr: *mut Self) { + std::mem::drop(BoxFFI::from_ptr(ptr)); + } +} + +/// Defines a pointer manipulation API for shared heap-allocated data. +/// +/// Implement this trait for types that require a shared ownership of data. +/// The data should be allocated via [`Arc::new`], and then returned to the user as a pointer. +/// The user is responsible for freeing the memory associated +/// with the pointer using corresponding driver's API function. +pub trait ArcFFI { + fn as_ptr(self: &Arc) -> *const Self { + Arc::as_ptr(self) + } + fn into_ptr(self: Arc) -> *const Self { + Arc::into_raw(self) + } + unsafe fn from_ptr(ptr: *const Self) -> Arc { + Arc::from_raw(ptr) + } + unsafe fn cloned_from_ptr(ptr: *const Self) -> Arc { + Arc::increment_strong_count(ptr); + Arc::from_raw(ptr) + } + unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { + ptr.as_ref() + } + unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + ptr.as_ref().unwrap() + } + unsafe fn free(ptr: *const Self) { + std::mem::drop(ArcFFI::from_ptr(ptr)); + } +} + +/// Defines a pointer manipulation API for data owned by some other object. +/// +/// Implement this trait for the types that do not need to be freed (directly) by the user. +/// The lifetime of the data is bound to some other object owning it. +/// +/// For example: lifetime of CassRow is bound by the lifetime of CassResult. +/// There is no API function that frees the CassRow. It should be automatically +/// freed when user calls cass_result_free. +pub trait RefFFI { + fn as_ptr(&self) -> *const Self { + self as *const Self + } + unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + ptr.as_ref().unwrap() + } +} From 80df19e6321f583d88ab9a5a608533ce023bb2e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Fri, 6 Jan 2023 19:04:25 +0100 Subject: [PATCH 4/5] treewide: Move to new FFI functions Implemented new traits for the types shared between C and Rust. Adjusted all places where ptr-to-ref (and vice-versa) conversions appear to use the new traits API. --- scylla-rust-wrapper/src/argconv.rs | 27 --- scylla-rust-wrapper/src/batch.rs | 28 ++-- scylla-rust-wrapper/src/binding.rs | 14 +- scylla-rust-wrapper/src/cass_types.rs | 64 ++++---- scylla-rust-wrapper/src/cluster.rs | 80 ++++----- scylla-rust-wrapper/src/collection.rs | 37 +++-- scylla-rust-wrapper/src/exec_profile.rs | 46 +++--- scylla-rust-wrapper/src/future.rs | 32 ++-- .../src/integration_testing.rs | 8 +- scylla-rust-wrapper/src/lib.rs | 2 +- scylla-rust-wrapper/src/logging.rs | 6 +- scylla-rust-wrapper/src/metadata.rs | 82 ++++++---- scylla-rust-wrapper/src/prepared.rs | 18 +- scylla-rust-wrapper/src/query_error.rs | 28 ++-- scylla-rust-wrapper/src/query_result.rs | 154 ++++++++++-------- scylla-rust-wrapper/src/retry_policy.rs | 15 +- scylla-rust-wrapper/src/session.rs | 48 +++--- scylla-rust-wrapper/src/ssl.rs | 16 +- scylla-rust-wrapper/src/statement.rs | 32 ++-- scylla-rust-wrapper/src/tuple.rs | 14 +- scylla-rust-wrapper/src/user_type.rs | 10 +- scylla-rust-wrapper/src/uuid.rs | 12 +- 22 files changed, 405 insertions(+), 368 deletions(-) diff --git a/scylla-rust-wrapper/src/argconv.rs b/scylla-rust-wrapper/src/argconv.rs index db244700..953e0ca2 100644 --- a/scylla-rust-wrapper/src/argconv.rs +++ b/scylla-rust-wrapper/src/argconv.rs @@ -4,33 +4,6 @@ use std::ffi::CStr; use std::os::raw::c_char; use std::sync::Arc; -pub unsafe fn ptr_to_ref(ptr: *const T) -> &'static T { - ptr.as_ref().unwrap() -} - -pub unsafe fn ptr_to_ref_mut(ptr: *mut T) -> &'static mut T { - ptr.as_mut().unwrap() -} - -pub unsafe fn free_boxed(ptr: *mut T) { - if !ptr.is_null() { - // This takes the ownership of the boxed value and drops it - let _ = Box::from_raw(ptr); - } -} - -pub unsafe fn clone_arced(ptr: *const T) -> Arc { - Arc::increment_strong_count(ptr); - Arc::from_raw(ptr) -} - -pub unsafe fn free_arced(ptr: *const T) { - if !ptr.is_null() { - // This decrements the arc's internal counter and potentially drops it - Arc::from_raw(ptr); - } -} - pub unsafe fn ptr_to_cstr(ptr: *const c_char) -> Option<&'static str> { CStr::from_ptr(ptr).to_str().ok() } diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index 3cdf36ee..fe18e548 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -1,4 +1,4 @@ -use crate::argconv::{free_boxed, ptr_to_ref, ptr_to_ref_mut}; +use crate::argconv::{ArcFFI, BoxFFI}; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; use crate::cass_types::{make_batch_type, CassBatchType}; @@ -19,6 +19,8 @@ pub struct CassBatch { pub(crate) exec_profile: Option, } +impl BoxFFI for CassBatch {} + #[derive(Clone)] pub struct CassBatchState { pub batch: Batch, @@ -28,7 +30,7 @@ pub struct CassBatchState { #[no_mangle] pub unsafe extern "C" fn cass_batch_new(type_: CassBatchType) -> *mut CassBatch { if let Some(batch_type) = make_batch_type(type_) { - Box::into_raw(Box::new(CassBatch { + BoxFFI::into_ptr(Box::new(CassBatch { state: Arc::new(CassBatchState { batch: Batch::new(batch_type), bound_values: Vec::new(), @@ -43,7 +45,7 @@ pub unsafe extern "C" fn cass_batch_new(type_: CassBatchType) -> *mut CassBatch #[no_mangle] pub unsafe extern "C" fn cass_batch_free(batch: *mut CassBatch) { - free_boxed(batch) + BoxFFI::free(batch); } #[no_mangle] @@ -51,7 +53,7 @@ pub unsafe extern "C" fn cass_batch_set_consistency( batch: *mut CassBatch, consistency: CassConsistency, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); let consistency = match consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -68,7 +70,7 @@ pub unsafe extern "C" fn cass_batch_set_serial_consistency( batch: *mut CassBatch, serial_consistency: CassConsistency, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); let serial_consistency = match serial_consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -85,10 +87,10 @@ pub unsafe extern "C" fn cass_batch_set_retry_policy( batch: *mut CassBatch, retry_policy: *const CassRetryPolicy, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); let maybe_arced_retry_policy: Option> = - retry_policy.as_ref().map(|policy| match policy { + ArcFFI::as_maybe_ref(retry_policy).map(|policy| match policy { CassRetryPolicy::DefaultRetryPolicy(default) => { default.clone() as Arc } @@ -108,7 +110,7 @@ pub unsafe extern "C" fn cass_batch_set_timestamp( batch: *mut CassBatch, timestamp: cass_int64_t, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); Arc::make_mut(&mut batch.state) .batch @@ -122,7 +124,7 @@ pub unsafe extern "C" fn cass_batch_set_request_timeout( batch: *mut CassBatch, timeout_ms: cass_uint64_t, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); batch.batch_request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -133,7 +135,7 @@ pub unsafe extern "C" fn cass_batch_set_is_idempotent( batch: *mut CassBatch, is_idempotent: cass_bool_t, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); Arc::make_mut(&mut batch.state) .batch .set_is_idempotent(is_idempotent != 0); @@ -146,7 +148,7 @@ pub unsafe extern "C" fn cass_batch_set_tracing( batch: *mut CassBatch, enabled: cass_bool_t, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); Arc::make_mut(&mut batch.state) .batch .set_tracing(enabled != 0); @@ -159,9 +161,9 @@ pub unsafe extern "C" fn cass_batch_add_statement( batch: *mut CassBatch, statement: *const CassStatement, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); let state = Arc::make_mut(&mut batch.state); - let statement = ptr_to_ref(statement); + let statement = BoxFFI::as_ref(statement); match &statement.statement { Statement::Simple(q) => state.batch.append_statement(q.query.clone()), diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index f2339604..e9768889 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -61,7 +61,7 @@ macro_rules! make_index_binder { #[allow(unused_imports)] use crate::value::CassCqlValue::*; match ($e)($($arg), *) { - Ok(v) => $consume_v(ptr_to_ref_mut(this), index as usize, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), index as usize, v), Err(e) => e, } } @@ -82,7 +82,7 @@ macro_rules! make_name_binder { use crate::value::CassCqlValue::*; let name = ptr_to_cstr(name).unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(ptr_to_ref_mut(this), name, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), name, v), Err(e) => e, } } @@ -104,7 +104,7 @@ macro_rules! make_name_n_binder { use crate::value::CassCqlValue::*; let name = ptr_to_cstr_n(name, name_length).unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(ptr_to_ref_mut(this), name, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), name, v), Err(e) => e, } } @@ -123,7 +123,7 @@ macro_rules! make_appender { #[allow(unused_imports)] use crate::value::CassCqlValue::*; match ($e)($($arg), *) { - Ok(v) => $consume_v(ptr_to_ref_mut(this), v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), v), Err(e) => e, } } @@ -303,7 +303,7 @@ macro_rules! invoke_binder_maker_macro_with_type { $consume_v, $fn, |p: *const crate::collection::CassCollection| { - match std::convert::TryInto::try_into(ptr_to_ref(p)) { + match std::convert::TryInto::try_into(BoxFFI::as_ref(p)) { Ok(v) => Ok(Some(v)), Err(_) => Err(CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE), } @@ -317,7 +317,7 @@ macro_rules! invoke_binder_maker_macro_with_type { $consume_v, $fn, |p: *const crate::tuple::CassTuple| { - Ok(Some(ptr_to_ref(p).into())) + Ok(Some(BoxFFI::as_ref(p).into())) }, [p @ *const crate::tuple::CassTuple] ); @@ -327,7 +327,7 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: *const crate::user_type::CassUserType| Ok(Some(ptr_to_ref(p).into())), + |p: *const crate::user_type::CassUserType| Ok(Some(BoxFFI::as_ref(p).into())), [p @ *const crate::user_type::CassUserType] ); }; diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 4f3f7f11..6b479020 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -176,6 +176,8 @@ pub enum CassDataTypeInner { Custom(String), } +impl ArcFFI for CassDataType {} + impl CassDataTypeInner { /// Checks for equality during typechecks. /// @@ -550,7 +552,7 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { CassDataType::new(inner) } -// Changed return type to const ptr - Arc::into_raw is const. +// Changed return type to const ptr - ArcFFI::into_ptr is const. // It's probably not a good idea - but cppdriver doesn't guarantee // thread safety apart from CassSession and CassFuture. // This comment also applies to other functions that create CassDataType. @@ -576,46 +578,45 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataTypeInner::Value(t), _ => return ptr::null_mut(), }; - - Arc::into_raw(CassDataType::new_arced(inner)) + ArcFFI::into_ptr(CassDataType::new_arced(inner)) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_from_existing( data_type: *const CassDataType, ) -> *const CassDataType { - let data_type = ptr_to_ref(data_type); - Arc::into_raw(CassDataType::new_arced(data_type.get_unchecked().clone())) + let data_type = ArcFFI::as_ref(data_type); + ArcFFI::into_ptr(CassDataType::new_arced(data_type.get_unchecked().clone())) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_tuple(item_count: size_t) -> *const CassDataType { - Arc::into_raw(CassDataType::new_arced(CassDataTypeInner::Tuple( + ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Tuple( Vec::with_capacity(item_count as usize), ))) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_udt(field_count: size_t) -> *const CassDataType { - Arc::into_raw(CassDataType::new_arced(CassDataTypeInner::UDT( + ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::UDT( UDTDataType::with_capacity(field_count as usize), ))) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_free(data_type: *mut CassDataType) { - free_arced(data_type); + ArcFFI::free(data_type); } #[no_mangle] pub unsafe extern "C" fn cass_data_type_type(data_type: *const CassDataType) -> CassValueType { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); data_type.get_unchecked().get_value_type() } #[no_mangle] pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType) -> cass_bool_t { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); let is_frozen = match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => udt.frozen, CassDataTypeInner::List { frozen, .. } => *frozen, @@ -633,7 +634,7 @@ pub unsafe extern "C" fn cass_data_type_type_name( type_name: *mut *const c_char, type_name_length: *mut size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type.get_unchecked() { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, type_name, type_name_length); @@ -657,7 +658,7 @@ pub unsafe extern "C" fn cass_data_type_set_type_name_n( type_name: *const c_char, type_name_length: size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type_raw); + let data_type = ArcFFI::as_ref(data_type_raw); let type_name_string = ptr_to_cstr_n(type_name, type_name_length) .unwrap() .to_string(); @@ -677,7 +678,7 @@ pub unsafe extern "C" fn cass_data_type_keyspace( keyspace: *mut *const c_char, keyspace_length: *mut size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type.get_unchecked() { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, keyspace, keyspace_length); @@ -701,7 +702,7 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace_n( keyspace: *const c_char, keyspace_length: size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); let keyspace_string = ptr_to_cstr_n(keyspace, keyspace_length) .unwrap() .to_string(); @@ -721,7 +722,7 @@ pub unsafe extern "C" fn cass_data_type_class_name( class_name: *mut *const ::std::os::raw::c_char, class_name_length: *mut size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type.get_unchecked() { CassDataTypeInner::Custom(name) => { write_str_to_c(name, class_name, class_name_length); @@ -745,7 +746,7 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( class_name: *const ::std::os::raw::c_char, class_name_length: size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); let class_string = ptr_to_cstr_n(class_name, class_name_length) .unwrap() .to_string(); @@ -760,7 +761,7 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDataType) -> size_t { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type.get_unchecked() { CassDataTypeInner::Value(..) => 0, CassDataTypeInner::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, @@ -787,14 +788,14 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type( data_type: *const CassDataType, index: size_t, ) -> *const CassDataType { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); let sub_type: Option<&Arc> = data_type.get_unchecked().get_sub_data_type(index as usize); match sub_type { None => std::ptr::null(), // Semantic from cppdriver which also returns non-owning pointer - Some(arc) => Arc::as_ptr(arc), + Some(arc) => ArcFFI::as_ptr(arc), } } @@ -812,12 +813,12 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( name: *const ::std::os::raw::c_char, name_length: size_t, ) -> *const CassDataType { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); let name_str = ptr_to_cstr_n(name, name_length).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => match udt.get_field_by_name(name_str) { None => std::ptr::null(), - Some(t) => Arc::as_ptr(t), + Some(t) => ArcFFI::as_ptr(t), }, _ => std::ptr::null(), } @@ -830,7 +831,7 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name( name: *mut *const ::std::os::raw::c_char, name_length: *mut size_t, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => match udt.field_types.get(index as usize) { None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, @@ -848,10 +849,10 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type( data_type: *const CassDataType, sub_data_type: *const CassDataType, ) -> CassError { - let data_type = ptr_to_ref(data_type); + let data_type = ArcFFI::as_ref(data_type); match data_type .get_mut_unchecked() - .add_sub_data_type(clone_arced(sub_data_type)) + .add_sub_data_type(ArcFFI::cloned_from_ptr(sub_data_type)) { Ok(()) => CassError::CASS_OK, Err(e) => e, @@ -875,9 +876,9 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( sub_data_type_raw: *const CassDataType, ) -> CassError { let name_string = ptr_to_cstr_n(name, name_length).unwrap().to_string(); - let sub_data_type = clone_arced(sub_data_type_raw); + let sub_data_type = ArcFFI::cloned_from_ptr(sub_data_type_raw); - let data_type = ptr_to_ref(data_type_raw); + let data_type = ArcFFI::as_ref(data_type_raw); match data_type.get_mut_unchecked() { CassDataTypeInner::UDT(udt_data_type) => { // The Cpp Driver does not check whether field_types size @@ -895,7 +896,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_value_type( sub_value_type: CassValueType, ) -> CassError { let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - cass_data_type_add_sub_type(data_type, Arc::as_ptr(&sub_data_type)) + cass_data_type_add_sub_type(data_type, ArcFFI::as_ptr(&sub_data_type)) } #[no_mangle] @@ -905,7 +906,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( sub_value_type: CassValueType, ) -> CassError { let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - cass_data_type_add_sub_type_by_name(data_type, name, Arc::as_ptr(&sub_data_type)) + cass_data_type_add_sub_type_by_name(data_type, name, ArcFFI::as_ptr(&sub_data_type)) } #[no_mangle] @@ -916,7 +917,12 @@ pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name_n( sub_value_type: CassValueType, ) -> CassError { let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - cass_data_type_add_sub_type_by_name_n(data_type, name, name_length, Arc::as_ptr(&sub_data_type)) + cass_data_type_add_sub_type_by_name_n( + data_type, + name, + name_length, + ArcFFI::as_ptr(&sub_data_type), + ) } impl TryFrom for Consistency { diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index ea2d500c..bde792b8 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -165,6 +165,8 @@ impl CassCluster { } } +impl BoxFFI for CassCluster {} + pub struct CassCustomPayload; // We want to make sure that the returned future does not depend @@ -215,7 +217,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster { .keepalive_timeout(DEFAULT_KEEPALIVE_TIMEOUT) }; - Box::into_raw(Box::new(CassCluster { + BoxFFI::into_ptr(Box::new(CassCluster { session_builder: default_session_builder, port: 9042, contact_points: Vec::new(), @@ -233,7 +235,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster { #[no_mangle] pub unsafe extern "C" fn cass_cluster_free(cluster: *mut CassCluster) { - free_boxed(cluster); + BoxFFI::free(cluster); } #[no_mangle] @@ -261,7 +263,7 @@ unsafe fn cluster_set_contact_points( contact_points_raw: *const c_char, contact_points_length: size_t, ) -> Result<(), CassError> { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let mut contact_points = ptr_to_cstr_n(contact_points_raw, contact_points_length) .ok_or(CassError::CASS_ERROR_LIB_BAD_PARAMS)? .split(',') @@ -309,7 +311,7 @@ pub unsafe extern "C" fn cass_cluster_set_application_name_n( app_name: *const c_char, app_name_len: size_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let app_name = ptr_to_cstr_n(app_name, app_name_len).unwrap().to_string(); cluster @@ -333,7 +335,7 @@ pub unsafe extern "C" fn cass_cluster_set_application_version_n( app_version: *const c_char, app_version_len: size_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let app_version = ptr_to_cstr_n(app_version, app_version_len) .unwrap() .to_string(); @@ -350,7 +352,7 @@ pub unsafe extern "C" fn cass_cluster_set_client_id( cluster_raw: *mut CassCluster, client_id: CassUuid, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let client_uuid: uuid::Uuid = client_id.into(); let client_uuid_str = client_uuid.to_string(); @@ -368,7 +370,7 @@ pub unsafe extern "C" fn cass_cluster_set_use_schema( cluster_raw: *mut CassCluster, enabled: cass_bool_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.session_builder.config.fetch_schema_metadata = enabled != 0; } @@ -377,7 +379,7 @@ pub unsafe extern "C" fn cass_cluster_set_tcp_nodelay( cluster_raw: *mut CassCluster, enabled: cass_bool_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.session_builder.config.tcp_nodelay = enabled != 0; } @@ -387,7 +389,7 @@ pub unsafe extern "C" fn cass_cluster_set_tcp_keepalive( enabled: cass_bool_t, delay_secs: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let enabled = enabled != 0; let tcp_keepalive_interval = enabled.then(|| Duration::from_secs(delay_secs as u64)); @@ -399,7 +401,7 @@ pub unsafe extern "C" fn cass_cluster_set_connection_heartbeat_interval( cluster_raw: *mut CassCluster, interval_secs: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let keepalive_interval = (interval_secs > 0).then(|| Duration::from_secs(interval_secs as u64)); cluster.session_builder.config.keepalive_interval = keepalive_interval; @@ -410,7 +412,7 @@ pub unsafe extern "C" fn cass_cluster_set_connection_idle_timeout( cluster_raw: *mut CassCluster, timeout_secs: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let keepalive_timeout = (timeout_secs > 0).then(|| Duration::from_secs(timeout_secs as u64)); cluster.session_builder.config.keepalive_timeout = keepalive_timeout; @@ -421,7 +423,7 @@ pub unsafe extern "C" fn cass_cluster_set_connect_timeout( cluster_raw: *mut CassCluster, timeout_ms: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.session_builder.config.connect_timeout = Duration::from_millis(timeout_ms.into()); } @@ -430,7 +432,7 @@ pub unsafe extern "C" fn cass_cluster_set_request_timeout( cluster_raw: *mut CassCluster, timeout_ms: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { // 0 -> no timeout @@ -443,7 +445,7 @@ pub unsafe extern "C" fn cass_cluster_set_max_schema_wait_time( cluster_raw: *mut CassCluster, wait_time_ms: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.session_builder.config.schema_agreement_timeout = Duration::from_millis(wait_time_ms.into()); @@ -454,7 +456,7 @@ pub unsafe extern "C" fn cass_cluster_set_schema_agreement_interval( cluster_raw: *mut CassCluster, interval_ms: c_uint, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.session_builder.config.schema_agreement_interval = Duration::from_millis(interval_ms.into()); @@ -469,7 +471,7 @@ pub unsafe extern "C" fn cass_cluster_set_port( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.port = port as u16; CassError::CASS_OK } @@ -501,14 +503,14 @@ pub unsafe extern "C" fn cass_cluster_set_credentials_n( let username = ptr_to_cstr_n(username_raw, username_length).unwrap(); let password = ptr_to_cstr_n(password_raw, password_length).unwrap(); - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.auth_username = Some(username.to_string()); cluster.auth_password = Some(password.to_string()); } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_round_robin(cluster_raw: *mut CassCluster) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); } @@ -561,7 +563,7 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_dc_aware_n( used_hosts_per_remote_dc: c_uint, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); set_load_balance_dc_aware_n( &mut cluster.load_balancing_config, @@ -595,7 +597,7 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_rack_aware_n( local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); set_load_balance_rack_aware_n( &mut cluster.load_balancing_config, @@ -707,7 +709,7 @@ pub unsafe extern "C" fn cass_cluster_set_use_beta_protocol_version( cluster_raw: *mut CassCluster, enable: cass_bool_t, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.use_beta_protocol_version = enable == cass_true; CassError::CASS_OK @@ -718,7 +720,7 @@ pub unsafe extern "C" fn cass_cluster_set_protocol_version( cluster_raw: *mut CassCluster, protocol_version: c_int, ) -> CassError { - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); if protocol_version == 4 && !cluster.use_beta_protocol_version { // Rust Driver supports only protocol version 4 @@ -747,7 +749,7 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); let policy = SimpleSpeculativeExecutionPolicy { max_retry_count: max_speculative_executions as usize, @@ -765,7 +767,7 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( pub unsafe extern "C" fn cass_cluster_set_no_speculative_execution_policy( cluster_raw: *mut CassCluster, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { builder.speculative_execution_policy(None) @@ -779,7 +781,7 @@ pub unsafe extern "C" fn cass_cluster_set_token_aware_routing( cluster_raw: *mut CassCluster, enabled: cass_bool_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster.load_balancing_config.token_awareness_enabled = enabled != 0; } @@ -788,7 +790,7 @@ pub unsafe extern "C" fn cass_cluster_set_token_aware_routing_shuffle_replicas( cluster_raw: *mut CassCluster, enabled: cass_bool_t, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); cluster .load_balancing_config @@ -800,9 +802,9 @@ pub unsafe extern "C" fn cass_cluster_set_retry_policy( cluster_raw: *mut CassCluster, retry_policy: *const CassRetryPolicy, ) { - let cluster = ptr_to_ref_mut(cluster_raw); + let cluster = BoxFFI::as_mut_ref(cluster_raw); - let retry_policy: Arc = match ptr_to_ref(retry_policy) { + let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { DefaultRetryPolicy(default) => Arc::clone(default) as _, FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, @@ -815,8 +817,8 @@ pub unsafe extern "C" fn cass_cluster_set_retry_policy( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_ssl(cluster: *mut CassCluster, ssl: *mut CassSsl) { - let cluster_from_raw = ptr_to_ref_mut(cluster); - let cass_ssl = clone_arced(ssl); + let cluster_from_raw = BoxFFI::as_mut_ref(cluster); + let cass_ssl = ArcFFI::cloned_from_ptr(ssl); let ssl_context_builder = SslContextBuilder::from_ptr(cass_ssl.ssl_context); // Reference count is increased as tokio_openssl will try to free `ssl_context` when calling `SSL_free`. @@ -830,7 +832,7 @@ pub unsafe extern "C" fn cass_cluster_set_compression( cluster: *mut CassCluster, compression_type: CassCompressionType, ) { - let cluster_from_raw = ptr_to_ref_mut(cluster); + let cluster_from_raw = BoxFFI::as_mut_ref(cluster); let compression = match compression_type { CassCompressionType::CASS_COMPRESSION_LZ4 => Some(Compression::Lz4), CassCompressionType::CASS_COMPRESSION_SNAPPY => Some(Compression::Snappy), @@ -845,7 +847,7 @@ pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing( cluster: *mut CassCluster, enabled: cass_bool_t, ) { - let cluster = ptr_to_ref_mut(cluster); + let cluster = BoxFFI::as_mut_ref(cluster); cluster.load_balancing_config.latency_awareness_enabled = enabled != 0; } @@ -858,7 +860,7 @@ pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing_settings( update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let cluster = ptr_to_ref_mut(cluster); + let cluster = BoxFFI::as_mut_ref(cluster); cluster.load_balancing_config.latency_awareness_builder = LatencyAwarenessBuilder::new() .exclusion_threshold(exclusion_threshold) .scale(Duration::from_millis(scale_ms)) @@ -872,7 +874,7 @@ pub unsafe extern "C" fn cass_cluster_set_consistency( cluster: *mut CassCluster, consistency: CassConsistency, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster); + let cluster = BoxFFI::as_mut_ref(cluster); let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -890,7 +892,7 @@ pub unsafe extern "C" fn cass_cluster_set_serial_consistency( cluster: *mut CassCluster, serial_consistency: CassConsistency, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster); + let cluster = BoxFFI::as_mut_ref(cluster); let serial_consistency: SerialConsistency = match serial_consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -919,7 +921,7 @@ pub unsafe extern "C" fn cass_cluster_set_execution_profile_n( name_length: size_t, profile: *const CassExecProfile, ) -> CassError { - let cluster = ptr_to_ref_mut(cluster); + let cluster = BoxFFI::as_mut_ref(cluster); let name = if let Some(name) = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()) { @@ -928,7 +930,7 @@ pub unsafe extern "C" fn cass_cluster_set_execution_profile_n( // Got NULL or empty string, which is invalid name for a profile. return CassError::CASS_ERROR_LIB_BAD_PARAMS; }; - let profile = if let Some(profile) = profile.as_ref() { + let profile = if let Some(profile) = BoxFFI::as_maybe_ref(profile) { profile.clone() } else { return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -963,7 +965,7 @@ mod tests { let cluster_raw = cass_cluster_new(); { /* Test valid configurations */ - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); { assert_matches!(cluster.load_balancing_config.load_balancing_kind, None); assert!(cluster.load_balancing_config.token_awareness_enabled); @@ -1108,7 +1110,7 @@ mod tests { let exec_profile_raw = cass_execution_profile_new(); { /* Test valid configurations */ - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); { assert!(cluster.execution_profile_map.is_empty()); } diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index 074bf5f0..dd6e0f2b 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -29,6 +29,8 @@ pub struct CassCollection { pub items: Vec, } +impl BoxFFI for CassCollection {} + impl CassCollection { fn typecheck_on_append(&self, value: &Option) -> CassError { // See https://github.com/scylladb/cpp-driver/blob/master/src/collection.hpp#L100. @@ -136,7 +138,7 @@ pub unsafe extern "C" fn cass_collection_new( _ => item_count, } as usize; - Box::into_raw(Box::new(CassCollection { + BoxFFI::into_ptr(Box::new(CassCollection { collection_type, data_type: None, capacity, @@ -149,7 +151,7 @@ unsafe extern "C" fn cass_collection_new_from_data_type( data_type: *const CassDataType, item_count: size_t, ) -> *mut CassCollection { - let data_type = clone_arced(data_type); + let data_type = ArcFFI::cloned_from_ptr(data_type); let (capacity, collection_type) = match data_type.get_unchecked() { CassDataTypeInner::List { .. } => { (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST) @@ -164,7 +166,7 @@ unsafe extern "C" fn cass_collection_new_from_data_type( }; let capacity = capacity as usize; - Box::into_raw(Box::new(CassCollection { + BoxFFI::into_ptr(Box::new(CassCollection { collection_type, data_type: Some(data_type), capacity, @@ -176,10 +178,10 @@ unsafe extern "C" fn cass_collection_new_from_data_type( unsafe extern "C" fn cass_collection_data_type( collection: *const CassCollection, ) -> *const CassDataType { - let collection_ref = ptr_to_ref(collection); + let collection_ref = BoxFFI::as_ref(collection); match &collection_ref.data_type { - Some(dt) => Arc::as_ptr(dt), + Some(dt) => ArcFFI::as_ptr(dt), None => match collection_ref.collection_type { CassCollectionType::CASS_COLLECTION_TYPE_LIST => &UNTYPED_LIST_TYPE, CassCollectionType::CASS_COLLECTION_TYPE_SET => &UNTYPED_SET_TYPE, @@ -195,7 +197,7 @@ unsafe extern "C" fn cass_collection_data_type( #[no_mangle] pub unsafe extern "C" fn cass_collection_free(collection: *mut CassCollection) { - free_boxed(collection); + BoxFFI::free(collection); } prepare_binders_macro!(@append CassCollection, |collection: &mut CassCollection, v| collection.append_cql_value(v)); @@ -220,9 +222,8 @@ make_binders!(user_type, cass_collection_append_user_type); #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::{ + argconv::ArcFFI, cass_error::CassError, cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType}, collection::{ @@ -269,7 +270,7 @@ mod tests { frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let untyped_map = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -300,7 +301,7 @@ mod tests { frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let half_typed_map = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -343,7 +344,7 @@ mod tests { ), frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr, 2); // First entry -> typecheck successful. @@ -372,7 +373,7 @@ mod tests { CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); - Arc::from_raw(dt_ptr); + ArcFFI::free(dt_ptr); cass_collection_free(bool_to_i16_map); } @@ -398,7 +399,7 @@ mod tests { frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let untyped_set = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -420,7 +421,7 @@ mod tests { ))), frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let bool_set = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -432,7 +433,7 @@ mod tests { CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); - Arc::from_raw(dt_ptr); + ArcFFI::free(dt_ptr); cass_collection_free(bool_set); } @@ -458,7 +459,7 @@ mod tests { frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let untyped_list = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -480,7 +481,7 @@ mod tests { ))), frozen: false, }); - let dt_ptr = Arc::into_raw(dt); + let dt_ptr = ArcFFI::into_ptr(dt); let bool_list = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( @@ -492,7 +493,7 @@ mod tests { CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); - Arc::from_raw(dt_ptr); + ArcFFI::free(dt_ptr); cass_collection_free(bool_list); } } diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index 79cc46a0..281d9a3f 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -13,7 +13,7 @@ use scylla::retry_policy::RetryPolicy; use scylla::speculative_execution::SimpleSpeculativeExecutionPolicy; use scylla::statement::Consistency; -use crate::argconv::{free_boxed, ptr_to_cstr_n, ptr_to_ref, ptr_to_ref_mut, strlen}; +use crate::argconv::{ptr_to_cstr_n, strlen, ArcFFI, BoxFFI}; use crate::batch::CassBatch; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; @@ -37,6 +37,8 @@ pub struct CassExecProfile { load_balancing_config: LoadBalancingConfig, } +impl BoxFFI for CassExecProfile {} + impl CassExecProfile { fn new() -> Self { Self { @@ -170,12 +172,12 @@ pub(crate) enum PerStatementExecProfileInner { #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_new() -> *mut CassExecProfile { - Box::into_raw(Box::new(CassExecProfile::new())) + BoxFFI::into_ptr(Box::new(CassExecProfile::new())) } #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_free(profile: *mut CassExecProfile) { - free_boxed(profile); + BoxFFI::free(profile); } /* Exec profiles scope setters */ @@ -194,7 +196,7 @@ pub unsafe extern "C" fn cass_statement_set_execution_profile_n( name: *const c_char, name_length: size_t, ) -> CassError { - let statement = ptr_to_ref_mut(statement); + let statement = BoxFFI::as_mut_ref(statement); let name: Option = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()); statement.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -216,7 +218,7 @@ pub unsafe extern "C" fn cass_batch_set_execution_profile_n( name: *const c_char, name_length: size_t, ) -> CassError { - let batch = ptr_to_ref_mut(batch); + let batch = BoxFFI::as_mut_ref(batch); let name: Option = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()); batch.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -249,7 +251,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_consistency( profile: *mut CassExecProfile, consistency: CassConsistency, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -264,7 +266,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_consistency( pub unsafe extern "C" fn cass_execution_profile_set_no_speculative_execution_policy( profile: *mut CassExecProfile, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder.modify_in_place(|builder| builder.speculative_execution_policy(None)); @@ -277,7 +279,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_constant_speculative_executi constant_delay_ms: cass_int64_t, max_speculative_executions: cass_int32_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); if constant_delay_ms < 0 || max_speculative_executions < 0 { return CassError::CASS_ERROR_LIB_BAD_PARAMS; } @@ -298,7 +300,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing( profile: *mut CassExecProfile, enabled: cass_bool_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder .load_balancing_config .latency_awareness_enabled = enabled != 0; @@ -315,7 +317,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing_settin update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder .load_balancing_config .latency_awareness_builder = LatencyAwarenessBuilder::new() @@ -349,7 +351,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware_n( used_hosts_per_remote_dc: cass_uint32_t, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); set_load_balance_dc_aware_n( &mut profile_builder.load_balancing_config, @@ -383,7 +385,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); set_load_balance_rack_aware_n( &mut profile_builder.load_balancing_config, @@ -398,7 +400,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( profile: *mut CassExecProfile, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); CassError::CASS_OK @@ -409,7 +411,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_request_timeout( profile: *mut CassExecProfile, timeout_ms: cass_uint64_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder.modify_in_place(|builder| { builder.request_timeout(Some(std::time::Duration::from_millis(timeout_ms))) }); @@ -422,12 +424,12 @@ pub unsafe extern "C" fn cass_execution_profile_set_retry_policy( profile: *mut CassExecProfile, retry_policy: *const CassRetryPolicy, ) -> CassError { - let retry_policy: Arc = match ptr_to_ref(retry_policy) { + let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { DefaultRetryPolicy(default) => Arc::clone(default) as _, FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, }; - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder.modify_in_place(|builder| builder.retry_policy(retry_policy)); CassError::CASS_OK @@ -438,7 +440,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_serial_consistency( profile: *mut CassExecProfile, serial_consistency: CassConsistency, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); let maybe_serial_consistency = if serial_consistency == CassConsistency::CASS_CONSISTENCY_UNKNOWN { @@ -459,7 +461,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing( profile: *mut CassExecProfile, enabled: cass_bool_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder .load_balancing_config .token_awareness_enabled = enabled != 0; @@ -472,7 +474,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing_shuffle_ profile: *mut CassExecProfile, enabled: cass_bool_t, ) -> CassError { - let profile_builder = ptr_to_ref_mut(profile); + let profile_builder = BoxFFI::as_mut_ref(profile); profile_builder .load_balancing_config .token_aware_shuffling_replicas_enabled = enabled != 0; @@ -517,7 +519,7 @@ mod tests { let profile_raw = cass_execution_profile_new(); { /* Test valid configurations */ - let profile = ptr_to_ref(profile_raw); + let profile = BoxFFI::as_ref(profile_raw); { assert_matches!(profile.load_balancing_config.load_balancing_kind, None); assert!(profile.load_balancing_config.token_awareness_enabled); @@ -622,8 +624,8 @@ mod tests { { /* Test valid configurations */ - let statement = ptr_to_ref(statement_raw); - let batch = ptr_to_ref(batch_raw); + let statement = BoxFFI::as_ref(statement_raw); + let batch = BoxFFI::as_ref(batch_raw); { assert!(statement.exec_profile.is_none()); assert!(batch.exec_profile.is_none()); diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 874aacdd..4fa84c3a 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -60,6 +60,8 @@ pub struct CassFuture { wait_for_value: Condvar, } +impl ArcFFI for CassFuture {} + /// An error that can appear during `cass_future_wait_timed`. enum FutureError { TimeoutError, @@ -275,7 +277,7 @@ impl CassFuture { } fn into_raw(self: Arc) -> *const Self { - Arc::into_raw(self) + ArcFFI::into_ptr(self) } } @@ -291,12 +293,12 @@ pub unsafe extern "C" fn cass_future_set_callback( callback: CassFutureCallback, data: *mut ::std::os::raw::c_void, ) -> CassError { - ptr_to_ref(future_raw).set_callback(callback, data) + ArcFFI::as_ref(future_raw).set_callback(callback, data) } #[no_mangle] pub unsafe extern "C" fn cass_future_wait(future_raw: *const CassFuture) { - ptr_to_ref(future_raw).with_waited_result(|_| ()); + ArcFFI::as_ref(future_raw).with_waited_result(|_| ()); } #[no_mangle] @@ -304,14 +306,14 @@ pub unsafe extern "C" fn cass_future_wait_timed( future_raw: *const CassFuture, timeout_us: cass_duration_t, ) -> cass_bool_t { - ptr_to_ref(future_raw) + ArcFFI::as_ref(future_raw) .with_waited_result_timed(|_| (), Duration::from_micros(timeout_us)) .is_ok() as cass_bool_t } #[no_mangle] pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cass_bool_t { - let state_guard = ptr_to_ref(future_raw).state.lock().unwrap(); + let state_guard = ArcFFI::as_ref(future_raw).state.lock().unwrap(); match state_guard.value { None => cass_false, Some(_) => cass_true, @@ -320,7 +322,7 @@ pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cas #[no_mangle] pub unsafe extern "C" fn cass_future_error_code(future_raw: *const CassFuture) -> CassError { - ptr_to_ref(future_raw).with_waited_result(|r: &mut CassFutureResult| match r { + ArcFFI::as_ref(future_raw).with_waited_result(|r: &mut CassFutureResult| match r { Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), Err((err, _)) => *err, _ => CassError::CASS_OK, @@ -333,7 +335,7 @@ pub unsafe extern "C" fn cass_future_error_message( message: *mut *const ::std::os::raw::c_char, message_length: *mut size_t, ) { - ptr_to_ref(future).with_waited_state(|state: &mut CassFutureState| { + ArcFFI::as_ref(future).with_waited_state(|state: &mut CassFutureState| { let value = &state.value; let msg = state .err_string @@ -348,49 +350,49 @@ pub unsafe extern "C" fn cass_future_error_message( #[no_mangle] pub unsafe extern "C" fn cass_future_free(future_raw: *const CassFuture) { - free_arced(future_raw); + ArcFFI::free(future_raw); } #[no_mangle] pub unsafe extern "C" fn cass_future_get_result( future_raw: *const CassFuture, ) -> *const CassResult { - ptr_to_ref(future_raw) + ArcFFI::as_ref(future_raw) .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryResult(qr) => Some(qr.clone()), _ => None, } }) - .map_or(std::ptr::null(), Arc::into_raw) + .map_or(std::ptr::null(), ArcFFI::into_ptr) } #[no_mangle] pub unsafe extern "C" fn cass_future_get_error_result( future_raw: *const CassFuture, ) -> *const CassErrorResult { - ptr_to_ref(future_raw) + ArcFFI::as_ref(future_raw) .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryError(qr) => Some(qr.clone()), _ => None, } }) - .map_or(std::ptr::null(), Arc::into_raw) + .map_or(std::ptr::null(), ArcFFI::into_ptr) } #[no_mangle] pub unsafe extern "C" fn cass_future_get_prepared( future_raw: *mut CassFuture, ) -> *const CassPrepared { - ptr_to_ref(future_raw) + ArcFFI::as_ref(future_raw) .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::Prepared(p) => Some(p.clone()), _ => None, } }) - .map_or(std::ptr::null(), Arc::into_raw) + .map_or(std::ptr::null(), ArcFFI::into_ptr) } #[no_mangle] @@ -398,7 +400,7 @@ pub unsafe extern "C" fn cass_future_tracing_id( future: *const CassFuture, tracing_id: *mut CassUuid, ) -> CassError { - ptr_to_ref(future).with_waited_result(|r: &mut CassFutureResult| match r { + ArcFFI::as_ref(future).with_waited_result(|r: &mut CassFutureResult| match r { Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { Some(id) => { *tracing_id = CassUuid::from(id); diff --git a/scylla-rust-wrapper/src/integration_testing.rs b/scylla-rust-wrapper/src/integration_testing.rs index 72526cf8..0fd4007f 100644 --- a/scylla-rust-wrapper/src/integration_testing.rs +++ b/scylla-rust-wrapper/src/integration_testing.rs @@ -1,7 +1,7 @@ use std::ffi::{c_char, CString}; use crate::{ - argconv::ptr_to_ref, + argconv::BoxFFI, cluster::CassCluster, types::{cass_int32_t, cass_uint16_t, size_t}, }; @@ -10,14 +10,14 @@ use crate::{ pub unsafe extern "C" fn testing_cluster_get_connect_timeout( cluster_raw: *const CassCluster, ) -> cass_uint16_t { - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); cluster.get_session_config().connect_timeout.as_millis() as cass_uint16_t } #[no_mangle] pub unsafe extern "C" fn testing_cluster_get_port(cluster_raw: *const CassCluster) -> cass_int32_t { - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); cluster.get_port() as cass_int32_t } @@ -28,7 +28,7 @@ pub unsafe extern "C" fn testing_cluster_get_contact_points( contact_points: *mut *mut c_char, contact_points_length: *mut size_t, ) { - let cluster = ptr_to_ref(cluster_raw); + let cluster = BoxFFI::as_ref(cluster_raw); let contact_points_string = cluster.get_contact_points().join(","); let length = contact_points_string.len(); diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index 9f027fdf..88a6af1f 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -122,7 +122,7 @@ lazy_static! { // #[no_mangle] // pub extern "C" fn create_foo() -> *mut Foo { -// Box::into_raw(Box::new(Foo)) +// BoxFFI::into_raw(Box::new(Foo)) // } // To borrow (and not free) from C: diff --git a/scylla-rust-wrapper/src/logging.rs b/scylla-rust-wrapper/src/logging.rs index 07ebe4de..c1a43f82 100644 --- a/scylla-rust-wrapper/src/logging.rs +++ b/scylla-rust-wrapper/src/logging.rs @@ -1,4 +1,4 @@ -use crate::argconv::{arr_to_cstr, ptr_to_cstr, ptr_to_ref, str_to_arr}; +use crate::argconv::{arr_to_cstr, ptr_to_cstr, str_to_arr, RefFFI}; use crate::cass_log_types::{CassLogLevel, CassLogMessage}; use crate::types::size_t; use crate::LOGGER; @@ -14,6 +14,8 @@ use tracing_subscriber::layer::Context; use tracing_subscriber::prelude::*; use tracing_subscriber::Layer; +impl RefFFI for CassLogMessage {} + pub type CassLogCallback = Option; @@ -63,7 +65,7 @@ impl TryFrom for Level { pub const CASS_LOG_MAX_MESSAGE_SIZE: usize = 1024; pub unsafe extern "C" fn stderr_log_callback(message: *const CassLogMessage, _data: *mut c_void) { - let message = ptr_to_ref(message); + let message = RefFFI::as_ref(message); eprintln!( "{} [{}] ({}:{}) {}", diff --git a/scylla-rust-wrapper/src/metadata.rs b/scylla-rust-wrapper/src/metadata.rs index 2422cd55..ca969069 100644 --- a/scylla-rust-wrapper/src/metadata.rs +++ b/scylla-rust-wrapper/src/metadata.rs @@ -13,6 +13,8 @@ pub struct CassSchemaMeta { pub keyspaces: HashMap, } +impl BoxFFI for CassSchemaMeta {} + pub struct CassKeyspaceMeta { pub name: String, @@ -22,6 +24,9 @@ pub struct CassKeyspaceMeta { pub views: HashMap>, } +// Owned by CassSchemaMeta +impl RefFFI for CassKeyspaceMeta {} + pub struct CassTableMeta { pub name: String, pub columns_metadata: HashMap, @@ -30,18 +35,29 @@ pub struct CassTableMeta { pub views: HashMap>, } +// Either: +// - owned by CassMaterializedViewMeta - won't be given to user +// - Owned by CassKeyspaceMeta (in Arc), referenced (Weak) by CassMaterializedViewMeta +impl RefFFI for CassTableMeta {} + pub struct CassMaterializedViewMeta { pub name: String, pub view_metadata: CassTableMeta, pub base_table: Weak, } +// Shared ownership by CassKeyspaceMeta and CassTableMeta +impl RefFFI for CassMaterializedViewMeta {} + pub struct CassColumnMeta { pub name: String, pub column_type: CassDataType, pub column_kind: CassColumnType, } +// Owned by CassTableMeta +impl RefFFI for CassColumnMeta {} + pub unsafe fn create_table_metadata( keyspace_name: &str, table_name: &str, @@ -82,7 +98,7 @@ pub unsafe fn create_table_metadata( #[no_mangle] pub unsafe extern "C" fn cass_schema_meta_free(schema_meta: *mut CassSchemaMeta) { - free_boxed(schema_meta) + BoxFFI::free(schema_meta); } #[no_mangle] @@ -103,7 +119,7 @@ pub unsafe extern "C" fn cass_schema_meta_keyspace_by_name_n( return std::ptr::null(); } - let metadata = ptr_to_ref(schema_meta); + let metadata = BoxFFI::as_ref(schema_meta); let keyspace = ptr_to_cstr_n(keyspace_name, keyspace_name_length).unwrap(); let keyspace_meta = metadata.keyspaces.get(keyspace); @@ -120,7 +136,7 @@ pub unsafe extern "C" fn cass_keyspace_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let keyspace_meta = ptr_to_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(keyspace_meta); write_str_to_c(keyspace_meta.name.as_str(), name, name_length) } @@ -142,14 +158,14 @@ pub unsafe extern "C" fn cass_keyspace_meta_user_type_by_name_n( return std::ptr::null(); } - let keyspace_meta = ptr_to_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(keyspace_meta); let user_type_name = ptr_to_cstr_n(type_, type_length).unwrap(); match keyspace_meta .user_defined_type_data_type .get(user_type_name) { - Some(udt) => Arc::as_ptr(udt), + Some(udt) => ArcFFI::as_ptr(udt), None => std::ptr::null(), } } @@ -172,13 +188,13 @@ pub unsafe extern "C" fn cass_keyspace_meta_table_by_name_n( return std::ptr::null(); } - let keyspace_meta = ptr_to_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(keyspace_meta); let table_name = ptr_to_cstr_n(table, table_length).unwrap(); let table_meta = keyspace_meta.tables.get(table_name); match table_meta { - Some(meta) => Arc::as_ptr(meta), + Some(meta) => RefFFI::as_ptr(meta), None => std::ptr::null(), } } @@ -189,13 +205,13 @@ pub unsafe extern "C" fn cass_table_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); write_str_to_c(table_meta.name.as_str(), name, name_length) } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_column_count(table_meta: *const CassTableMeta) -> size_t { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); table_meta.columns_metadata.len() as size_t } @@ -204,7 +220,7 @@ pub unsafe extern "C" fn cass_table_meta_partition_key( table_meta: *const CassTableMeta, index: size_t, ) -> *const CassColumnMeta { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); match table_meta.partition_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { @@ -219,7 +235,7 @@ pub unsafe extern "C" fn cass_table_meta_partition_key( pub unsafe extern "C" fn cass_table_meta_partition_key_count( table_meta: *const CassTableMeta, ) -> size_t { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); table_meta.partition_keys.len() as size_t } @@ -228,7 +244,7 @@ pub unsafe extern "C" fn cass_table_meta_clustering_key( table_meta: *const CassTableMeta, index: size_t, ) -> *const CassColumnMeta { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); match table_meta.clustering_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { @@ -243,7 +259,7 @@ pub unsafe extern "C" fn cass_table_meta_clustering_key( pub unsafe extern "C" fn cass_table_meta_clustering_key_count( table_meta: *const CassTableMeta, ) -> size_t { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); table_meta.clustering_keys.len() as size_t } @@ -265,7 +281,7 @@ pub unsafe extern "C" fn cass_table_meta_column_by_name_n( return std::ptr::null(); } - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); let column_name = ptr_to_cstr_n(column, column_length).unwrap(); match table_meta.columns_metadata.get(column_name) { @@ -280,7 +296,7 @@ pub unsafe extern "C" fn cass_column_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let column_meta = ptr_to_ref(column_meta); + let column_meta = RefFFI::as_ref(column_meta); write_str_to_c(column_meta.name.as_str(), name, name_length) } @@ -288,7 +304,7 @@ pub unsafe extern "C" fn cass_column_meta_name( pub unsafe extern "C" fn cass_column_meta_data_type( column_meta: *const CassColumnMeta, ) -> *const CassDataType { - let column_meta = ptr_to_ref(column_meta); + let column_meta = RefFFI::as_ref(column_meta); &column_meta.column_type as *const CassDataType } @@ -296,7 +312,7 @@ pub unsafe extern "C" fn cass_column_meta_data_type( pub unsafe extern "C" fn cass_column_meta_type( column_meta: *const CassColumnMeta, ) -> CassColumnType { - let column_meta = ptr_to_ref(column_meta); + let column_meta = RefFFI::as_ref(column_meta); column_meta.column_kind } @@ -318,11 +334,11 @@ pub unsafe extern "C" fn cass_keyspace_meta_materialized_view_by_name_n( return std::ptr::null(); } - let keyspace_meta = ptr_to_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(keyspace_meta); let view_name = ptr_to_cstr_n(view, view_length).unwrap(); match keyspace_meta.views.get(view_name) { - Some(view_meta) => Arc::as_ptr(view_meta), + Some(view_meta) => RefFFI::as_ptr(view_meta.as_ref()), None => std::ptr::null(), } } @@ -345,11 +361,11 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name_n( return std::ptr::null(); } - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); let view_name = ptr_to_cstr_n(view, view_length).unwrap(); match table_meta.views.get(view_name) { - Some(view_meta) => Arc::as_ptr(view_meta), + Some(view_meta) => RefFFI::as_ptr(view_meta.as_ref()), None => std::ptr::null(), } } @@ -358,7 +374,7 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name_n( pub unsafe extern "C" fn cass_table_meta_materialized_view_count( table_meta: *const CassTableMeta, ) -> size_t { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); table_meta.views.len() as size_t } @@ -367,10 +383,10 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view( table_meta: *const CassTableMeta, index: size_t, ) -> *const CassMaterializedViewMeta { - let table_meta = ptr_to_ref(table_meta); + let table_meta = RefFFI::as_ref(table_meta); match table_meta.views.iter().nth(index as usize) { - Some(view_meta) => Arc::as_ptr(view_meta.1), + Some(view_meta) => RefFFI::as_ptr(view_meta.1.as_ref()), None => std::ptr::null(), } } @@ -393,7 +409,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column_by_name_n( return std::ptr::null(); } - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); let column_name = ptr_to_cstr_n(column, column_length).unwrap(); match view_meta.view_metadata.columns_metadata.get(column_name) { @@ -408,7 +424,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); write_str_to_c(view_meta.name.as_str(), name, name_length) } @@ -416,7 +432,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_name( pub unsafe extern "C" fn cass_materialized_view_meta_base_table( view_meta: *const CassMaterializedViewMeta, ) -> *const CassTableMeta { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); view_meta.base_table.as_ptr() } @@ -424,7 +440,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_base_table( pub unsafe extern "C" fn cass_materialized_view_meta_column_count( view_meta: *const CassMaterializedViewMeta, ) -> size_t { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); view_meta.view_metadata.columns_metadata.len() as size_t } @@ -433,7 +449,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column( view_meta: *const CassMaterializedViewMeta, index: size_t, ) -> *const CassColumnMeta { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); match view_meta .view_metadata @@ -450,7 +466,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column( pub unsafe extern "C" fn cass_materialized_view_meta_partition_key_count( view_meta: *const CassMaterializedViewMeta, ) -> size_t { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); view_meta.view_metadata.partition_keys.len() as size_t } @@ -458,7 +474,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_partition_key( view_meta: *const CassMaterializedViewMeta, index: size_t, ) -> *const CassColumnMeta { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); match view_meta.view_metadata.partition_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { @@ -473,7 +489,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_partition_key( pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key_count( view_meta: *const CassMaterializedViewMeta, ) -> size_t { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); view_meta.view_metadata.clustering_keys.len() as size_t } @@ -481,7 +497,7 @@ pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key( view_meta: *const CassMaterializedViewMeta, index: size_t, ) -> *const CassColumnMeta { - let view_meta = ptr_to_ref(view_meta); + let view_meta = RefFFI::as_ref(view_meta); match view_meta.view_metadata.clustering_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 53b45db1..3fc4927e 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -72,23 +72,25 @@ impl CassPrepared { } } +impl ArcFFI for CassPrepared {} + #[no_mangle] pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { - free_arced(prepared_raw); + ArcFFI::free(prepared_raw); } #[no_mangle] pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: *const CassPrepared, ) -> *mut CassStatement { - let prepared: Arc<_> = clone_arced(prepared_raw); + let prepared: Arc<_> = ArcFFI::cloned_from_ptr(prepared_raw); let bound_values_size = prepared.statement.get_variable_col_specs().len(); // cloning prepared statement's arc, because creating CassStatement should not invalidate // the CassPrepared argument let statement = Statement::Prepared(prepared); - Box::into_raw(Box::new(CassStatement { + BoxFFI::into_ptr(Box::new(CassStatement { statement, bound_values: vec![Unset; bound_values_size], paging_state: PagingState::start(), @@ -106,7 +108,7 @@ pub unsafe extern "C" fn cass_prepared_parameter_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let prepared = ptr_to_ref(prepared_raw); + let prepared = ArcFFI::as_ref(prepared_raw); match prepared .statement @@ -126,10 +128,10 @@ pub unsafe extern "C" fn cass_prepared_parameter_data_type( prepared_raw: *const CassPrepared, index: size_t, ) -> *const CassDataType { - let prepared = ptr_to_ref(prepared_raw); + let prepared = ArcFFI::as_ref(prepared_raw); match prepared.variable_col_data_types.get(index as usize) { - Some(dt) => Arc::as_ptr(dt), + Some(dt) => ArcFFI::as_ptr(dt), None => std::ptr::null(), } } @@ -148,13 +150,13 @@ pub unsafe extern "C" fn cass_prepared_parameter_data_type_by_name_n( name: *const c_char, name_length: size_t, ) -> *const CassDataType { - let prepared = ptr_to_ref(prepared_raw); + let prepared = ArcFFI::as_ref(prepared_raw); let parameter_name = ptr_to_cstr_n(name, name_length).expect("Prepared parameter name is not UTF-8"); let data_type = prepared.get_variable_data_type_by_name(parameter_name); match data_type { - Some(dt) => Arc::as_ptr(dt), + Some(dt) => ArcFFI::as_ptr(dt), None => std::ptr::null(), } } diff --git a/scylla-rust-wrapper/src/query_error.rs b/scylla-rust-wrapper/src/query_error.rs index f6a91e1c..a6aa376f 100644 --- a/scylla-rust-wrapper/src/query_error.rs +++ b/scylla-rust-wrapper/src/query_error.rs @@ -19,6 +19,8 @@ pub enum CassErrorResult { Deserialization(#[from] DeserializationError), } +impl ArcFFI for CassErrorResult {} + impl From for CassConsistency { fn from(c: Consistency) -> CassConsistency { match c { @@ -55,12 +57,12 @@ impl From<&WriteType> for CassWriteType { #[no_mangle] pub unsafe extern "C" fn cass_error_result_free(error_result: *const CassErrorResult) { - free_arced(error_result); + ArcFFI::free(error_result); } #[no_mangle] pub unsafe extern "C" fn cass_error_result_code(error_result: *const CassErrorResult) -> CassError { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); error_result.to_cass_error() } @@ -68,7 +70,7 @@ pub unsafe extern "C" fn cass_error_result_code(error_result: *const CassErrorRe pub unsafe extern "C" fn cass_error_result_consistency( error_result: *const CassErrorResult, ) -> CassConsistency { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::Unavailable { consistency, .. }, @@ -98,7 +100,7 @@ pub unsafe extern "C" fn cass_error_result_consistency( pub unsafe extern "C" fn cass_error_result_responses_received( error_result: *const CassErrorResult, ) -> cass_int32_t { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { alive, .. }, _)) => { *alive @@ -123,7 +125,7 @@ pub unsafe extern "C" fn cass_error_result_responses_received( pub unsafe extern "C" fn cass_error_result_responses_required( error_result: *const CassErrorResult, ) -> cass_int32_t { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { required, .. }, _)) => { *required @@ -148,7 +150,7 @@ pub unsafe extern "C" fn cass_error_result_responses_required( pub unsafe extern "C" fn cass_error_result_num_failures( error_result: *const CassErrorResult, ) -> cass_int32_t { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::ReadFailure { numfailures, .. }, @@ -166,7 +168,7 @@ pub unsafe extern "C" fn cass_error_result_num_failures( pub unsafe extern "C" fn cass_error_result_data_present( error_result: *const CassErrorResult, ) -> cass_bool_t { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::ReadTimeout { data_present, .. }, @@ -196,7 +198,7 @@ pub unsafe extern "C" fn cass_error_result_data_present( pub unsafe extern "C" fn cass_error_result_write_type( error_result: *const CassErrorResult, ) -> CassWriteType { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::WriteTimeout { write_type, .. }, @@ -216,7 +218,7 @@ pub unsafe extern "C" fn cass_error_result_keyspace( c_keyspace: *mut *const ::std::os::raw::c_char, c_keyspace_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { keyspace, .. }, _)) => { write_str_to_c(keyspace.as_str(), c_keyspace, c_keyspace_len); @@ -239,7 +241,7 @@ pub unsafe extern "C" fn cass_error_result_table( c_table: *mut *const ::std::os::raw::c_char, c_table_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { table, .. }, _)) => { write_str_to_c(table.as_str(), c_table, c_table_len); @@ -255,7 +257,7 @@ pub unsafe extern "C" fn cass_error_result_function( c_function: *mut *const ::std::os::raw::c_char, c_function_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { function, .. }, @@ -270,7 +272,7 @@ pub unsafe extern "C" fn cass_error_result_function( #[no_mangle] pub unsafe extern "C" fn cass_error_num_arg_types(error_result: *const CassErrorResult) -> size_t { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { arg_types, .. }, @@ -287,7 +289,7 @@ pub unsafe extern "C" fn cass_error_result_arg_type( arg_type: *mut *const ::std::os::raw::c_char, arg_type_length: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ptr_to_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { arg_types, .. }, diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 78b8a7e7..5cce03ef 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -95,6 +95,8 @@ impl CassResult { } } +impl ArcFFI for CassResult {} + #[derive(Debug)] pub struct CassResultMetadata { pub col_specs: Vec, @@ -147,6 +149,8 @@ pub struct CassRow { pub result_metadata: Arc, } +impl RefFFI for CassRow {} + pub fn create_cass_rows_from_rows( rows: Vec, metadata: &Arc, @@ -181,6 +185,8 @@ pub struct CassValue { pub value_type: Arc, } +impl RefFFI for CassValue {} + fn create_cass_row_columns(row: Row, metadata: &Arc) -> Vec { row.columns .into_iter() @@ -361,15 +367,17 @@ pub enum CassIterator { CassViewMetaIterator(CassViewMetaIterator), } +impl BoxFFI for CassIterator {} + #[no_mangle] pub unsafe extern "C" fn cass_iterator_free(iterator: *mut CassIterator) { - free_boxed(iterator); + BoxFFI::free(iterator); } // After creating an iterator we have to call next() before accessing the value #[no_mangle] pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass_bool_t { - let mut iter = ptr_to_ref_mut(iterator); + let mut iter = BoxFFI::as_mut_ref(iterator); match &mut iter { CassIterator::CassResultIterator(result_iterator) => { @@ -469,7 +477,7 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_row(iterator: *const CassIterator) -> *const CassRow { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); // Defined only for result iterator, for other types should return null if let CassIterator::CassResultIterator(result_iterator) = iter { @@ -497,7 +505,7 @@ pub unsafe extern "C" fn cass_iterator_get_row(iterator: *const CassIterator) -> pub unsafe extern "C" fn cass_iterator_get_column( iterator: *const CassIterator, ) -> *const CassValue { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); // Defined only for row iterator, for other types should return null if let CassIterator::CassRowIterator(row_iterator) = iter { @@ -521,7 +529,7 @@ pub unsafe extern "C" fn cass_iterator_get_column( pub unsafe extern "C" fn cass_iterator_get_value( iterator: *const CassIterator, ) -> *const CassValue { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); // Defined only for collections(list, set and map) or tuple iterator, for other types should return null if let CassIterator::CassCollectionIterator(collection_iterator) = iter { @@ -558,7 +566,7 @@ pub unsafe extern "C" fn cass_iterator_get_value( pub unsafe extern "C" fn cass_iterator_get_map_key( iterator: *const CassIterator, ) -> *const CassValue { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassMapIterator(map_iterator) = iter { let iter_position = match map_iterator.position { @@ -585,7 +593,7 @@ pub unsafe extern "C" fn cass_iterator_get_map_key( pub unsafe extern "C" fn cass_iterator_get_map_value( iterator: *const CassIterator, ) -> *const CassValue { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassMapIterator(map_iterator) = iter { let iter_position = match map_iterator.position { @@ -614,7 +622,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassUdtIterator(udt_iterator) = iter { let iter_position = match udt_iterator.position { @@ -647,7 +655,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( pub unsafe extern "C" fn cass_iterator_get_user_type_field_value( iterator: *const CassIterator, ) -> *const CassValue { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassUdtIterator(udt_iterator) = iter { let iter_position = match udt_iterator.position { @@ -678,7 +686,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_value( pub unsafe extern "C" fn cass_iterator_get_keyspace_meta( iterator: *const CassIterator, ) -> *const CassKeyspaceMeta { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassSchemaMetaIterator(schema_meta_iterator) = iter { let iter_position = match schema_meta_iterator.position { @@ -705,7 +713,7 @@ pub unsafe extern "C" fn cass_iterator_get_keyspace_meta( pub unsafe extern "C" fn cass_iterator_get_table_meta( iterator: *const CassIterator, ) -> *const CassTableMeta { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassKeyspaceMetaTableIterator(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { @@ -720,7 +728,7 @@ pub unsafe extern "C" fn cass_iterator_get_table_meta( .nth(iter_position); return match table_meta_entry_opt { - Some(table_meta_entry) => Arc::as_ptr(table_meta_entry.1), + Some(table_meta_entry) => RefFFI::as_ptr(table_meta_entry.1.as_ref()), None => std::ptr::null(), }; } @@ -732,7 +740,7 @@ pub unsafe extern "C" fn cass_iterator_get_table_meta( pub unsafe extern "C" fn cass_iterator_get_user_type( iterator: *const CassIterator, ) -> *const CassDataType { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); if let CassIterator::CassKeyspaceMetaUserTypeIterator(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { @@ -747,7 +755,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type( .nth(iter_position); return match udt_to_type_entry_opt { - Some(udt_to_type_entry) => Arc::as_ptr(udt_to_type_entry.1), + Some(udt_to_type_entry) => ArcFFI::as_ptr(udt_to_type_entry.1), None => std::ptr::null(), }; } @@ -759,7 +767,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type( pub unsafe extern "C" fn cass_iterator_get_column_meta( iterator: *const CassIterator, ) -> *const CassColumnMeta { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); match iter { CassIterator::CassTableMetaIterator(table_meta_iterator) => { @@ -805,7 +813,7 @@ pub unsafe extern "C" fn cass_iterator_get_column_meta( pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta( iterator: *const CassIterator, ) -> *const CassMaterializedViewMeta { - let iter = ptr_to_ref(iterator); + let iter = BoxFFI::as_ref(iterator); match iter { CassIterator::CassKeyspaceMetaViewIterator(keyspace_meta_iterator) => { @@ -817,7 +825,7 @@ pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta( let view_meta_entry_opt = keyspace_meta_iterator.value.views.iter().nth(iter_position); match view_meta_entry_opt { - Some(view_meta_entry) => Arc::as_ptr(view_meta_entry.1), + Some(view_meta_entry) => RefFFI::as_ptr(view_meta_entry.1.as_ref()), None => std::ptr::null(), } } @@ -830,7 +838,7 @@ pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta( let view_meta_entry_opt = table_meta_iterator.value.views.iter().nth(iter_position); match view_meta_entry_opt { - Some(view_meta_entry) => Arc::as_ptr(view_meta_entry.1), + Some(view_meta_entry) => RefFFI::as_ptr(view_meta_entry.1.as_ref()), None => std::ptr::null(), } } @@ -840,26 +848,26 @@ pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta( #[no_mangle] pub unsafe extern "C" fn cass_iterator_from_result(result: *const CassResult) -> *mut CassIterator { - let result_from_raw = clone_arced(result); + let result_from_raw = ArcFFI::cloned_from_ptr(result); let iterator = CassResultIterator { result: result_from_raw, position: None, }; - Box::into_raw(Box::new(CassIterator::CassResultIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassResultIterator(iterator))) } #[no_mangle] pub unsafe extern "C" fn cass_iterator_from_row(row: *const CassRow) -> *mut CassIterator { - let row_from_raw = ptr_to_ref(row); + let row_from_raw = RefFFI::as_ref(row); let iterator = CassRowIterator { row: row_from_raw, position: None, }; - Box::into_raw(Box::new(CassIterator::CassRowIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassRowIterator(iterator))) } #[no_mangle] @@ -872,7 +880,7 @@ pub unsafe extern "C" fn cass_iterator_from_collection( return std::ptr::null_mut(); } - let val = ptr_to_ref(value); + let val = RefFFI::as_ref(value); let item_count = cass_value_item_count(value); let item_count = match cass_value_type(value) { CassValueType::CASS_VALUE_TYPE_MAP => item_count * 2, @@ -885,12 +893,12 @@ pub unsafe extern "C" fn cass_iterator_from_collection( position: None, }; - Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassCollectionIterator(iterator))) } #[no_mangle] pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *mut CassIterator { - let tuple = ptr_to_ref(value); + let tuple = RefFFI::as_ref(value); if let Some(Value::CollectionValue(Collection::Tuple(val))) = &tuple.value { let item_count = val.len(); @@ -900,7 +908,7 @@ pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *m position: None, }; - return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); + return BoxFFI::into_ptr(Box::new(CassIterator::CassCollectionIterator(iterator))); } std::ptr::null_mut() @@ -908,7 +916,7 @@ pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *m #[no_mangle] pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut CassIterator { - let map = ptr_to_ref(value); + let map = RefFFI::as_ref(value); if let Some(Value::CollectionValue(Collection::Map(val))) = &map.value { let item_count = val.len(); @@ -918,7 +926,7 @@ pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut position: None, }; - return Box::into_raw(Box::new(CassIterator::CassMapIterator(iterator))); + return BoxFFI::into_ptr(Box::new(CassIterator::CassMapIterator(iterator))); } std::ptr::null_mut() @@ -928,7 +936,7 @@ pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut pub unsafe extern "C" fn cass_iterator_fields_from_user_type( value: *const CassValue, ) -> *mut CassIterator { - let udt = ptr_to_ref(value); + let udt = RefFFI::as_ref(value); if let Some(Value::CollectionValue(Collection::UserDefinedType { fields, .. })) = &udt.value { let item_count = fields.len(); @@ -938,7 +946,7 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type( position: None, }; - return Box::into_raw(Box::new(CassIterator::CassUdtIterator(iterator))); + return BoxFFI::into_ptr(Box::new(CassIterator::CassUdtIterator(iterator))); } std::ptr::null_mut() @@ -948,7 +956,7 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type( pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta( schema_meta: *const CassSchemaMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(schema_meta); + let metadata = BoxFFI::as_ref(schema_meta); let iterator = CassSchemaMetaIterator { value: metadata, @@ -956,14 +964,14 @@ pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassSchemaMetaIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassSchemaMetaIterator(iterator))) } #[no_mangle] pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(keyspace_meta); + let metadata = RefFFI::as_ref(keyspace_meta); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -971,7 +979,7 @@ pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassKeyspaceMetaTableIterator( + BoxFFI::into_ptr(Box::new(CassIterator::CassKeyspaceMetaTableIterator( iterator, ))) } @@ -980,7 +988,7 @@ pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta( pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(keyspace_meta); + let metadata = RefFFI::as_ref(keyspace_meta); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -988,7 +996,7 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassKeyspaceMetaViewIterator( + BoxFFI::into_ptr(Box::new(CassIterator::CassKeyspaceMetaViewIterator( iterator, ))) } @@ -997,7 +1005,7 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta( pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(keyspace_meta); + let metadata = RefFFI::as_ref(keyspace_meta); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1005,7 +1013,7 @@ pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassKeyspaceMetaUserTypeIterator( + BoxFFI::into_ptr(Box::new(CassIterator::CassKeyspaceMetaUserTypeIterator( iterator, ))) } @@ -1014,7 +1022,7 @@ pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta( pub unsafe extern "C" fn cass_iterator_columns_from_table_meta( table_meta: *const CassTableMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(table_meta); + let metadata = RefFFI::as_ref(table_meta); let iterator = CassTableMetaIterator { value: metadata, @@ -1022,13 +1030,13 @@ pub unsafe extern "C" fn cass_iterator_columns_from_table_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassTableMetaIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassTableMetaIterator(iterator))) } pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta( table_meta: *const CassTableMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(table_meta); + let metadata = RefFFI::as_ref(table_meta); let iterator = CassTableMetaIterator { value: metadata, @@ -1036,13 +1044,13 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassTableMetaIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassTableMetaIterator(iterator))) } pub unsafe extern "C" fn cass_iterator_columns_from_materialized_view_meta( view_meta: *const CassMaterializedViewMeta, ) -> *mut CassIterator { - let metadata = ptr_to_ref(view_meta); + let metadata = RefFFI::as_ref(view_meta); let iterator = CassViewMetaIterator { value: metadata, @@ -1050,17 +1058,17 @@ pub unsafe extern "C" fn cass_iterator_columns_from_materialized_view_meta( position: None, }; - Box::into_raw(Box::new(CassIterator::CassViewMetaIterator(iterator))) + BoxFFI::into_ptr(Box::new(CassIterator::CassViewMetaIterator(iterator))) } #[no_mangle] pub unsafe extern "C" fn cass_result_free(result_raw: *const CassResult) { - free_arced(result_raw); + ArcFFI::free(result_raw); } #[no_mangle] pub unsafe extern "C" fn cass_result_has_more_pages(result: *const CassResult) -> cass_bool_t { - let result = ptr_to_ref(result); + let result = ArcFFI::as_ref(result); (!result.paging_state_response.finished()) as cass_bool_t } @@ -1069,7 +1077,7 @@ pub unsafe extern "C" fn cass_row_get_column( row_raw: *const CassRow, index: size_t, ) -> *const CassValue { - let row: &CassRow = ptr_to_ref(row_raw); + let row: &CassRow = RefFFI::as_ref(row_raw); let index_usize: usize = index.try_into().unwrap(); let column_value = match row.columns.get(index_usize) { @@ -1097,7 +1105,7 @@ pub unsafe extern "C" fn cass_row_get_column_by_name_n( name: *const c_char, name_length: size_t, ) -> *const CassValue { - let row_from_raw = ptr_to_ref(row); + let row_from_raw = RefFFI::as_ref(row); let mut name_str = ptr_to_cstr_n(name, name_length).unwrap(); let mut is_case_sensitive = false; @@ -1130,7 +1138,7 @@ pub unsafe extern "C" fn cass_result_column_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let result_from_raw = ptr_to_ref(result); + let result_from_raw = ArcFFI::as_ref(result); let index_usize: usize = index.try_into().unwrap(); let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result_from_raw.kind else { @@ -1165,7 +1173,7 @@ pub unsafe extern "C" fn cass_result_column_data_type( result: *const CassResult, index: size_t, ) -> *const CassDataType { - let result_from_raw: &CassResult = ptr_to_ref(result); + let result_from_raw: &CassResult = ArcFFI::as_ref(result); let index_usize: usize = index .try_into() .expect("Provided index is out of bounds. Max possible value is usize::MAX"); @@ -1177,22 +1185,22 @@ pub unsafe extern "C" fn cass_result_column_data_type( metadata .col_specs .get(index_usize) - .map(|col_spec| Arc::as_ptr(&col_spec.data_type)) + .map(|col_spec| ArcFFI::as_ptr(&col_spec.data_type)) .unwrap_or(std::ptr::null()) } #[no_mangle] pub unsafe extern "C" fn cass_value_type(value: *const CassValue) -> CassValueType { - let value_from_raw = ptr_to_ref(value); + let value_from_raw = RefFFI::as_ref(value); - cass_data_type_type(Arc::as_ptr(&value_from_raw.value_type)) + cass_data_type_type(ArcFFI::as_ptr(&value_from_raw.value_type)) } #[no_mangle] pub unsafe extern "C" fn cass_value_data_type(value: *const CassValue) -> *const CassDataType { - let value_from_raw = ptr_to_ref(value); + let value_from_raw = RefFFI::as_ref(value); - Arc::as_ptr(&value_from_raw.value_type) + ArcFFI::as_ptr(&value_from_raw.value_type) } macro_rules! val_ptr_to_ref_ensure_non_null { @@ -1200,7 +1208,7 @@ macro_rules! val_ptr_to_ref_ensure_non_null { if $ptr.is_null() { return CassError::CASS_ERROR_LIB_NULL_VALUE; } - ptr_to_ref($ptr) + RefFFI::as_ref($ptr) }}; } @@ -1373,7 +1381,7 @@ pub unsafe extern "C" fn cass_value_get_decimal( varint_size: *mut size_t, scale: *mut cass_int32_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = RefFFI::as_ref(value); let decimal = match &val.value { Some(Value::RegularValue(CqlValue::Decimal(decimal))) => decimal, Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -1464,13 +1472,13 @@ pub unsafe extern "C" fn cass_value_get_bytes( #[no_mangle] pub unsafe extern "C" fn cass_value_is_null(value: *const CassValue) -> cass_bool_t { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = RefFFI::as_ref(value); val.value.is_none() as cass_bool_t } #[no_mangle] pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> cass_bool_t { - let val = ptr_to_ref(value); + let val = RefFFI::as_ref(value); matches!( val.value_type.get_unchecked().get_value_type(), @@ -1482,7 +1490,7 @@ pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> ca #[no_mangle] pub unsafe extern "C" fn cass_value_is_duration(value: *const CassValue) -> cass_bool_t { - let val = ptr_to_ref(value); + let val = RefFFI::as_ref(value); (val.value_type.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) as cass_bool_t @@ -1490,7 +1498,7 @@ pub unsafe extern "C" fn cass_value_is_duration(value: *const CassValue) -> cass #[no_mangle] pub unsafe extern "C" fn cass_value_item_count(collection: *const CassValue) -> size_t { - let val = ptr_to_ref(collection); + let val = RefFFI::as_ref(collection); match &val.value { Some(Value::CollectionValue(Collection::List(list))) => list.len() as size_t, @@ -1508,7 +1516,7 @@ pub unsafe extern "C" fn cass_value_item_count(collection: *const CassValue) -> pub unsafe extern "C" fn cass_value_primary_sub_type( collection: *const CassValue, ) -> CassValueType { - let val = ptr_to_ref(collection); + let val = RefFFI::as_ref(collection); match val.value_type.get_unchecked() { CassDataTypeInner::List { @@ -1527,7 +1535,7 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( pub unsafe extern "C" fn cass_value_secondary_sub_type( collection: *const CassValue, ) -> CassValueType { - let val = ptr_to_ref(collection); + let val = RefFFI::as_ref(collection); match val.value_type.get_unchecked() { CassDataTypeInner::Map { @@ -1540,7 +1548,7 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( #[no_mangle] pub unsafe extern "C" fn cass_result_row_count(result_raw: *const CassResult) -> size_t { - let result = ptr_to_ref(result_raw); + let result = ArcFFI::as_ref(result_raw); let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { return 0; @@ -1551,7 +1559,7 @@ pub unsafe extern "C" fn cass_result_row_count(result_raw: *const CassResult) -> #[no_mangle] pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) -> size_t { - let result = ptr_to_ref(result_raw); + let result = ArcFFI::as_ref(result_raw); let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result.kind else { return 0; @@ -1562,7 +1570,7 @@ pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) #[no_mangle] pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> *const CassRow { - let result = ptr_to_ref(result_raw); + let result = ArcFFI::as_ref(result_raw); let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { return std::ptr::null(); @@ -1583,7 +1591,7 @@ pub unsafe extern "C" fn cass_result_paging_state_token( return CassError::CASS_ERROR_LIB_NO_PAGING_STATE; } - let result_from_raw = ptr_to_ref(result); + let result_from_raw = ArcFFI::as_ref(result); match &result_from_raw.paging_state_response { PagingStateResponse::HasMorePages { state } => match state.as_bytes_slice() { @@ -1615,11 +1623,12 @@ mod tests { }; use crate::{ + argconv::ArcFFI, cass_error::CassError, cass_types::{CassDataType, CassDataTypeInner, CassValueType}, query_result::{ cass_result_column_data_type, cass_result_column_name, cass_result_first_row, - ptr_to_cstr_n, ptr_to_ref, size_t, + ptr_to_cstr_n, size_t, }, }; @@ -1723,21 +1732,24 @@ mod tests { // cass_result_column_data_type test { - let first_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 0)); + let first_col_data_type = + ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 0)); assert_eq!( &CassDataType::new(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_BIGINT )), first_col_data_type ); - let second_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 1)); + let second_col_data_type = + ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 1)); assert_eq!( &CassDataType::new(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_VARINT )), second_col_data_type ); - let third_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 2)); + let third_col_data_type = + ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 2)); assert_eq!( &CassDataType::new(CassDataTypeInner::List { typ: Some(CassDataType::new_arced(CassDataTypeInner::Value( diff --git a/scylla-rust-wrapper/src/retry_policy.rs b/scylla-rust-wrapper/src/retry_policy.rs index 45884ed2..6945c32a 100644 --- a/scylla-rust-wrapper/src/retry_policy.rs +++ b/scylla-rust-wrapper/src/retry_policy.rs @@ -1,8 +1,9 @@ -use crate::argconv::free_boxed; use scylla::retry_policy::{DefaultRetryPolicy, FallthroughRetryPolicy}; use scylla::transport::downgrading_consistency_retry_policy::DowngradingConsistencyRetryPolicy; use std::sync::Arc; +use crate::argconv::ArcFFI; + pub enum RetryPolicy { DefaultRetryPolicy(Arc), FallthroughRetryPolicy(Arc), @@ -11,28 +12,30 @@ pub enum RetryPolicy { pub type CassRetryPolicy = RetryPolicy; +impl ArcFFI for CassRetryPolicy {} + #[no_mangle] pub extern "C" fn cass_retry_policy_default_new() -> *const CassRetryPolicy { - Box::into_raw(Box::new(RetryPolicy::DefaultRetryPolicy(Arc::new( + ArcFFI::into_ptr(Arc::new(RetryPolicy::DefaultRetryPolicy(Arc::new( DefaultRetryPolicy, )))) } #[no_mangle] pub extern "C" fn cass_retry_policy_downgrading_consistency_new() -> *const CassRetryPolicy { - Box::into_raw(Box::new(RetryPolicy::DowngradingConsistencyRetryPolicy( + ArcFFI::into_ptr(Arc::new(RetryPolicy::DowngradingConsistencyRetryPolicy( Arc::new(DowngradingConsistencyRetryPolicy), ))) } #[no_mangle] pub extern "C" fn cass_retry_policy_fallthrough_new() -> *const CassRetryPolicy { - Box::into_raw(Box::new(RetryPolicy::FallthroughRetryPolicy(Arc::new( + ArcFFI::into_ptr(Arc::new(RetryPolicy::FallthroughRetryPolicy(Arc::new( FallthroughRetryPolicy, )))) } #[no_mangle] -pub unsafe extern "C" fn cass_retry_policy_free(retry_policy: *mut CassRetryPolicy) { - free_boxed(retry_policy); +pub unsafe extern "C" fn cass_retry_policy_free(retry_policy: *const CassRetryPolicy) { + ArcFFI::free(retry_policy); } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 8010adf5..1517cebc 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -139,10 +139,12 @@ impl CassSessionInner { pub type CassSession = RwLock>; +impl ArcFFI for CassSession {} + #[no_mangle] pub unsafe extern "C" fn cass_session_new() -> *mut CassSession { let session = Arc::new(RwLock::new(None::)); - Arc::into_raw(session) as *mut CassSession + ArcFFI::into_ptr(session) as *mut CassSession } #[no_mangle] @@ -150,8 +152,8 @@ pub unsafe extern "C" fn cass_session_connect( session_raw: *mut CassSession, cluster_raw: *const CassCluster, ) -> *const CassFuture { - let session_opt = ptr_to_ref(session_raw); - let cluster: &CassCluster = ptr_to_ref(cluster_raw); + let session_opt = ArcFFI::as_ref(session_raw); + let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); CassSessionInner::connect(session_opt, cluster, None) } @@ -172,8 +174,8 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n( keyspace: *const c_char, keyspace_length: size_t, ) -> *const CassFuture { - let session_opt = ptr_to_ref(session_raw); - let cluster: &CassCluster = ptr_to_ref(cluster_raw); + let session_opt = ArcFFI::as_ref(session_raw); + let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); let keyspace = ptr_to_cstr_n(keyspace, keyspace_length).map(ToOwned::to_owned); CassSessionInner::connect(session_opt, cluster, keyspace) @@ -184,8 +186,8 @@ pub unsafe extern "C" fn cass_session_execute_batch( session_raw: *mut CassSession, batch_raw: *const CassBatch, ) -> *const CassFuture { - let session_opt = ptr_to_ref(session_raw); - let batch_from_raw = ptr_to_ref(batch_raw); + let session_opt = ArcFFI::as_ref(session_raw); + let batch_from_raw = BoxFFI::as_ref(batch_raw); let mut state = batch_from_raw.state.clone(); let request_timeout_ms = batch_from_raw.batch_request_timeout_ms; @@ -250,10 +252,10 @@ pub unsafe extern "C" fn cass_session_execute( session_raw: *mut CassSession, statement_raw: *const CassStatement, ) -> *const CassFuture { - let session_opt = ptr_to_ref(session_raw); + let session_opt = ArcFFI::as_ref(session_raw); // DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault. - let statement_opt = ptr_to_ref(statement_raw); + let statement_opt = BoxFFI::as_ref(statement_raw); let paging_state = statement_opt.paging_state.clone(); let paging_enabled = statement_opt.paging_enabled; let bound_values = statement_opt.bound_values.clone(); @@ -377,8 +379,8 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( cass_session: *mut CassSession, statement: *const CassStatement, ) -> *const CassFuture { - let session = ptr_to_ref(cass_session); - let cass_statement = ptr_to_ref(statement); + let session = ArcFFI::as_ref(cass_session); + let cass_statement = BoxFFI::as_ref(statement); let statement = cass_statement.statement.clone(); CassFuture::make_raw(async move { @@ -429,7 +431,7 @@ pub unsafe extern "C" fn cass_session_prepare_n( // There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`. .unwrap_or_default(); let query = Query::new(query_str.to_string()); - let cass_session: &CassSession = ptr_to_ref(cass_session_raw); + let cass_session = ArcFFI::as_ref(cass_session_raw); CassFuture::make_raw(async move { let session_guard = cass_session.read().await; @@ -457,12 +459,12 @@ pub unsafe extern "C" fn cass_session_prepare_n( #[no_mangle] pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) { - free_arced(session_raw); + ArcFFI::free(session_raw); } #[no_mangle] pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const CassFuture { - let session_opt = ptr_to_ref(session); + let session_opt = ArcFFI::as_ref(session); CassFuture::make_raw(async move { let mut session_guard = session_opt.write().await; @@ -481,7 +483,7 @@ pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const #[no_mangle] pub unsafe extern "C" fn cass_session_get_client_id(session: *const CassSession) -> CassUuid { - let cass_session = ptr_to_ref(session); + let cass_session = ArcFFI::as_ref(session); let client_id: uuid::Uuid = cass_session.blocking_read().as_ref().unwrap().client_id; client_id.into() @@ -491,7 +493,7 @@ pub unsafe extern "C" fn cass_session_get_client_id(session: *const CassSession) pub unsafe extern "C" fn cass_session_get_schema_meta( session: *const CassSession, ) -> *const CassSchemaMeta { - let cass_session = ptr_to_ref(session); + let cass_session = ArcFFI::as_ref(session); let mut keyspaces: HashMap = HashMap::new(); for (keyspace_name, keyspace) in cass_session @@ -565,7 +567,7 @@ pub unsafe extern "C" fn cass_session_get_schema_meta( ); } - Box::into_raw(Box::new(CassSchemaMeta { keyspaces })) + BoxFFI::into_ptr(Box::new(CassSchemaMeta { keyspaces })) } #[cfg(test)] @@ -580,7 +582,7 @@ mod tests { use super::*; use crate::{ - argconv::{make_c_str, ptr_to_ref}, + argconv::make_c_str, batch::{ cass_batch_add_statement, cass_batch_free, cass_batch_new, cass_batch_set_retry_policy, }, @@ -721,7 +723,7 @@ mod tests { cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); // Initially, the profile map is empty. - assert!(ptr_to_ref(session_raw) + assert!(ArcFFI::as_ref(session_raw) .blocking_read() .as_ref() .unwrap() @@ -730,7 +732,7 @@ mod tests { cass_cluster_set_execution_profile(cluster_raw, make_c_str!("prof"), profile_raw); // Mutations in cluster do not affect the session that was connected before. - assert!(ptr_to_ref(session_raw) + assert!(ArcFFI::as_ref(session_raw) .blocking_read() .as_ref() .unwrap() @@ -741,7 +743,7 @@ mod tests { // Mutations in cluster are now propagated to the session. cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); - let profile_map_keys = ptr_to_ref(session_raw) + let profile_map_keys = ArcFFI::as_ref(session_raw) .blocking_read() .as_ref() .unwrap() @@ -827,8 +829,8 @@ mod tests { cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); { /* Test valid configurations */ - let statement = ptr_to_ref(statement_raw); - let batch = ptr_to_ref(batch_raw); + let statement = BoxFFI::as_ref(statement_raw); + let batch = BoxFFI::as_ref(batch_raw); { assert!(statement.exec_profile.is_none()); assert!(batch.exec_profile.is_none()); diff --git a/scylla-rust-wrapper/src/ssl.rs b/scylla-rust-wrapper/src/ssl.rs index 2e98370f..ba14a24b 100644 --- a/scylla-rust-wrapper/src/ssl.rs +++ b/scylla-rust-wrapper/src/ssl.rs @@ -1,4 +1,4 @@ -use crate::argconv::{clone_arced, free_arced}; +use crate::argconv::ArcFFI; use crate::cass_error::CassError; use crate::types::size_t; use libc::{c_int, strlen}; @@ -19,6 +19,8 @@ pub struct CassSsl { pub(crate) trusted_store: *mut X509_STORE, } +impl ArcFFI for CassSsl {} + pub const CASS_SSL_VERIFY_NONE: i32 = 0x00; pub const CASS_SSL_VERIFY_PEER_CERT: i32 = 0x01; pub const CASS_SSL_VERIFY_PEER_IDENTITY: i32 = 0x02; @@ -43,7 +45,7 @@ pub unsafe extern "C" fn cass_ssl_new_no_lib_init() -> *const CassSsl { trusted_store, }; - Arc::into_raw(Arc::new(ssl)) + ArcFFI::into_ptr(Arc::new(ssl)) } // This is required for the type system to impl Send + Sync for Arc. @@ -63,7 +65,7 @@ impl Drop for CassSsl { #[no_mangle] pub unsafe extern "C" fn cass_ssl_free(ssl: *mut CassSsl) { - free_arced(ssl); + ArcFFI::free(ssl); } unsafe extern "C" fn pem_password_callback( @@ -110,7 +112,7 @@ pub unsafe extern "C" fn cass_ssl_add_trusted_cert_n( cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = clone_arced(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl); let bio = BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()); if bio.is_null() { @@ -138,7 +140,7 @@ pub unsafe extern "C" fn cass_ssl_add_trusted_cert_n( #[no_mangle] pub unsafe extern "C" fn cass_ssl_set_verify_flags(ssl: *mut CassSsl, flags: i32) { - let ssl = clone_arced(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl); match flags { CASS_SSL_VERIFY_NONE => { @@ -176,7 +178,7 @@ pub unsafe extern "C" fn cass_ssl_set_cert_n( cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = clone_arced(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl); let bio = BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()); if bio.is_null() { @@ -269,7 +271,7 @@ pub unsafe extern "C" fn cass_ssl_set_private_key_n( password: *mut c_char, _password_length: size_t, ) -> CassError { - let ssl = clone_arced(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl); let bio = BIO_new_mem_buf(key as *const c_void, key_length.try_into().unwrap()); if bio.is_null() { diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index 22c4f23b..7627d116 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -42,6 +42,8 @@ pub struct CassStatement { pub(crate) exec_profile: Option, } +impl BoxFFI for CassStatement {} + impl CassStatement { fn bind_cql_value(&mut self, index: usize, value: Option) -> CassError { let (bound_value, maybe_data_type) = match &self.statement { @@ -182,7 +184,7 @@ pub unsafe extern "C" fn cass_statement_new_n( name_to_bound_index: HashMap::with_capacity(parameter_count as usize), }; - Box::into_raw(Box::new(CassStatement { + BoxFFI::into_ptr(Box::new(CassStatement { statement: Statement::Simple(simple_query), bound_values: vec![Unset; parameter_count as usize], paging_state: PagingState::start(), @@ -195,7 +197,7 @@ pub unsafe extern "C" fn cass_statement_new_n( #[no_mangle] pub unsafe extern "C" fn cass_statement_free(statement_raw: *mut CassStatement) { - free_boxed(statement_raw); + BoxFFI::free(statement_raw); } #[no_mangle] @@ -206,7 +208,7 @@ pub unsafe extern "C" fn cass_statement_set_consistency( let consistency_opt = get_consistency_from_cass_consistency(consistency); if let Some(consistency) = consistency_opt { - match &mut ptr_to_ref_mut(statement).statement { + match &mut BoxFFI::as_mut_ref(statement).statement { Statement::Simple(inner) => inner.query.set_consistency(consistency), Statement::Prepared(inner) => { Arc::make_mut(inner).statement.set_consistency(consistency) @@ -222,7 +224,7 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( statement_raw: *mut CassStatement, page_size: c_int, ) -> CassError { - let statement = ptr_to_ref_mut(statement_raw); + let statement = BoxFFI::as_mut_ref(statement_raw); if page_size <= 0 { // Cpp driver sets the page size flag only for positive page size provided by user. statement.paging_enabled = false; @@ -242,8 +244,8 @@ pub unsafe extern "C" fn cass_statement_set_paging_state( statement: *mut CassStatement, result: *const CassResult, ) -> CassError { - let statement = ptr_to_ref_mut(statement); - let result = ptr_to_ref(result); + let statement = BoxFFI::as_mut_ref(statement); + let result = ArcFFI::as_ref(result); match &result.paging_state_response { PagingStateResponse::HasMorePages { state } => statement.paging_state.clone_from(state), @@ -258,7 +260,7 @@ pub unsafe extern "C" fn cass_statement_set_paging_state_token( paging_state: *const c_char, paging_state_size: size_t, ) -> CassError { - let statement_from_raw = ptr_to_ref_mut(statement); + let statement_from_raw = BoxFFI::as_mut_ref(statement); if paging_state.is_null() { statement_from_raw.paging_state = PagingState::start(); @@ -276,7 +278,7 @@ pub unsafe extern "C" fn cass_statement_set_is_idempotent( statement_raw: *mut CassStatement, is_idempotent: cass_bool_t, ) -> CassError { - match &mut ptr_to_ref_mut(statement_raw).statement { + match &mut BoxFFI::as_mut_ref(statement_raw).statement { Statement::Simple(inner) => inner.query.set_is_idempotent(is_idempotent != 0), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -291,7 +293,7 @@ pub unsafe extern "C" fn cass_statement_set_tracing( statement_raw: *mut CassStatement, enabled: cass_bool_t, ) -> CassError { - match &mut ptr_to_ref_mut(statement_raw).statement { + match &mut BoxFFI::as_mut_ref(statement_raw).statement { Statement::Simple(inner) => inner.query.set_tracing(enabled != 0), Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_tracing(enabled != 0), } @@ -305,7 +307,7 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( retry_policy: *const CassRetryPolicy, ) -> CassError { let maybe_arced_retry_policy: Option> = - retry_policy.as_ref().map(|policy| match policy { + ArcFFI::as_maybe_ref(retry_policy).map(|policy| match policy { CassRetryPolicy::DefaultRetryPolicy(default) => { default.clone() as Arc } @@ -313,7 +315,7 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( CassRetryPolicy::DowngradingConsistencyRetryPolicy(downgrading) => downgrading.clone(), }); - match &mut ptr_to_ref_mut(statement).statement { + match &mut BoxFFI::as_mut_ref(statement).statement { Statement::Simple(inner) => inner.query.set_retry_policy(maybe_arced_retry_policy), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -342,7 +344,7 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( _ => return CassError::CASS_ERROR_LIB_BAD_PARAMS, }; - match &mut ptr_to_ref_mut(statement).statement { + match &mut BoxFFI::as_mut_ref(statement).statement { Statement::Simple(inner) => inner.query.set_serial_consistency(Some(consistency)), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -374,7 +376,7 @@ pub unsafe extern "C" fn cass_statement_set_timestamp( statement: *mut CassStatement, timestamp: cass_int64_t, ) -> CassError { - match &mut ptr_to_ref_mut(statement).statement { + match &mut BoxFFI::as_mut_ref(statement).statement { Statement::Simple(inner) => inner.query.set_timestamp(Some(timestamp)), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -397,7 +399,7 @@ pub unsafe extern "C" fn cass_statement_set_request_timeout( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let statement_from_raw = ptr_to_ref_mut(statement); + let statement_from_raw = BoxFFI::as_mut_ref(statement); statement_from_raw.request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -408,7 +410,7 @@ pub unsafe extern "C" fn cass_statement_reset_parameters( statement_raw: *mut CassStatement, count: size_t, ) -> CassError { - let statement = ptr_to_ref_mut(statement_raw); + let statement = BoxFFI::as_mut_ref(statement_raw); statement.reset_bound_values(count as usize); CassError::CASS_OK diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index fd8d5fa4..93602c63 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -15,6 +15,8 @@ pub struct CassTuple { pub items: Vec>, } +impl BoxFFI for CassTuple {} + impl CassTuple { fn get_types(&self) -> Option<&Vec>> { match &self.data_type { @@ -60,7 +62,7 @@ impl From<&CassTuple> for CassCqlValue { #[no_mangle] pub unsafe extern "C" fn cass_tuple_new(item_count: size_t) -> *mut CassTuple { - Box::into_raw(Box::new(CassTuple { + BoxFFI::into_ptr(Box::new(CassTuple { data_type: None, items: vec![None; item_count as usize], })) @@ -70,12 +72,12 @@ pub unsafe extern "C" fn cass_tuple_new(item_count: size_t) -> *mut CassTuple { unsafe extern "C" fn cass_tuple_new_from_data_type( data_type: *const CassDataType, ) -> *mut CassTuple { - let data_type = clone_arced(data_type); + let data_type = ArcFFI::cloned_from_ptr(data_type); let item_count = match data_type.get_unchecked() { CassDataTypeInner::Tuple(v) => v.len(), _ => return std::ptr::null_mut(), }; - Box::into_raw(Box::new(CassTuple { + BoxFFI::into_ptr(Box::new(CassTuple { data_type: Some(data_type), items: vec![None; item_count], })) @@ -83,13 +85,13 @@ unsafe extern "C" fn cass_tuple_new_from_data_type( #[no_mangle] unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) { - free_boxed(tuple) + BoxFFI::free(tuple); } #[no_mangle] unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType { - match &ptr_to_ref(tuple).data_type { - Some(t) => Arc::as_ptr(t), + match &BoxFFI::as_ref(tuple).data_type { + Some(t) => ArcFFI::as_ptr(t), None => &UNTYPED_TUPLE_TYPE, } } diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 1775ace7..9335bfac 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -14,6 +14,8 @@ pub struct CassUserType { pub field_values: Vec>, } +impl BoxFFI for CassUserType {} + impl CassUserType { fn set_field_by_index(&mut self, index: usize, value: Option) -> CassError { if index >= self.field_values.len() { @@ -83,12 +85,12 @@ impl From<&CassUserType> for CassCqlValue { pub unsafe extern "C" fn cass_user_type_new_from_data_type( data_type_raw: *const CassDataType, ) -> *mut CassUserType { - let data_type = clone_arced(data_type_raw); + let data_type = ArcFFI::cloned_from_ptr(data_type_raw); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt_data_type) => { let field_values = vec![None; udt_data_type.field_types.len()]; - Box::into_raw(Box::new(CassUserType { + BoxFFI::into_ptr(Box::new(CassUserType { data_type, field_values, })) @@ -99,13 +101,13 @@ pub unsafe extern "C" fn cass_user_type_new_from_data_type( #[no_mangle] pub unsafe extern "C" fn cass_user_type_free(user_type: *mut CassUserType) { - free_boxed(user_type); + BoxFFI::free(user_type); } #[no_mangle] pub unsafe extern "C" fn cass_user_type_data_type( user_type: *const CassUserType, ) -> *const CassDataType { - Arc::as_ptr(&ptr_to_ref(user_type).data_type) + ArcFFI::as_ptr(&BoxFFI::as_ref(user_type).data_type) } prepare_binders_macro!(@index_and_name CassUserType, diff --git a/scylla-rust-wrapper/src/uuid.rs b/scylla-rust-wrapper/src/uuid.rs index 52374dcc..b2343803 100644 --- a/scylla-rust-wrapper/src/uuid.rs +++ b/scylla-rust-wrapper/src/uuid.rs @@ -17,6 +17,8 @@ pub struct CassUuidGen { pub last_timestamp: AtomicU64, } +impl BoxFFI for CassUuidGen {} + // Implementation directly ported from Cpp Driver implementation: const TIME_OFFSET_BETWEEN_UTC_AND_EPOCH: u64 = 0x01B21DD213814000; // Nanoseconds @@ -113,7 +115,7 @@ pub unsafe extern "C" fn cass_uuid_gen_new() -> *mut CassUuidGen { // Masking the same way as in Cpp Driver. let node: u64 = (hasher.finish() & 0x0000FFFFFFFFFFFF) | 0x0000010000000000 /* Multicast bit */; - Box::into_raw(Box::new(CassUuidGen { + BoxFFI::into_ptr(Box::new(CassUuidGen { clock_seq_and_node: rand_clock_seq_and_node(node), last_timestamp: AtomicU64::new(0), })) @@ -121,7 +123,7 @@ pub unsafe extern "C" fn cass_uuid_gen_new() -> *mut CassUuidGen { #[no_mangle] pub unsafe extern "C" fn cass_uuid_gen_new_with_node(node: cass_uint64_t) -> *mut CassUuidGen { - Box::into_raw(Box::new(CassUuidGen { + BoxFFI::into_ptr(Box::new(CassUuidGen { clock_seq_and_node: rand_clock_seq_and_node(node & 0x0000FFFFFFFFFFFF), last_timestamp: AtomicU64::new(0), })) @@ -129,7 +131,7 @@ pub unsafe extern "C" fn cass_uuid_gen_new_with_node(node: cass_uint64_t) -> *mu #[no_mangle] pub unsafe extern "C" fn cass_uuid_gen_time(uuid_gen: *mut CassUuidGen, output: *mut CassUuid) { - let uuid_gen = ptr_to_ref_mut(uuid_gen); + let uuid_gen = BoxFFI::as_mut_ref(uuid_gen); let uuid = CassUuid { time_and_version: set_version(monotonic_timestamp(&mut uuid_gen.last_timestamp), 1), @@ -159,7 +161,7 @@ pub unsafe extern "C" fn cass_uuid_gen_from_time( timestamp: cass_uint64_t, output: *mut CassUuid, ) { - let uuid_gen = ptr_to_ref_mut(uuid_gen); + let uuid_gen = BoxFFI::as_mut_ref(uuid_gen); let uuid = CassUuid { time_and_version: set_version(from_unix_timestamp(timestamp), 1), @@ -249,5 +251,5 @@ pub unsafe extern "C" fn cass_uuid_from_string_n( #[no_mangle] pub unsafe extern "C" fn cass_uuid_gen_free(uuid_gen: *mut CassUuidGen) { - free_boxed(uuid_gen); + BoxFFI::free(uuid_gen); } From dd80a17a7d74f80441c071c6cf28a143ee10897d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Fri, 6 Jan 2023 12:14:50 +0100 Subject: [PATCH 5/5] clippy: warn on FFI not using new traits --- scylla-rust-wrapper/clippy.toml | 21 +++++++++++++++++++++ scylla-rust-wrapper/src/argconv.rs | 13 +++++++++++++ scylla-rust-wrapper/src/future.rs | 1 + 3 files changed, 35 insertions(+) create mode 100644 scylla-rust-wrapper/clippy.toml diff --git a/scylla-rust-wrapper/clippy.toml b/scylla-rust-wrapper/clippy.toml new file mode 100644 index 00000000..7256d6f6 --- /dev/null +++ b/scylla-rust-wrapper/clippy.toml @@ -0,0 +1,21 @@ +disallowed-methods = [ + "std::boxed::Box::from_raw", + "std::boxed::Box::from_raw_in", + "std::boxed::Box::into_raw", + "std::boxed::Box::into_raw_with_allocator", + + "std::sync::Arc::as_ptr", + "std::sync::Arc::decrement_strong_count", + "std::sync::Arc::from_raw", + "std::sync::Arc::increment_strong_count", + "std::sync::Arc::into_raw", + + "std::rc::Rc::as_ptr", + "std::rc::Rc::decrement_strong_count", + "std::rc::Rc::from_raw", + "std::rc::Rc::increment_strong_count", + "std::rc::Rc::into_raw", + + "const_ptr::as_ref", + "mut_ptr::as_mut" +] diff --git a/scylla-rust-wrapper/src/argconv.rs b/scylla-rust-wrapper/src/argconv.rs index 953e0ca2..058ea45e 100644 --- a/scylla-rust-wrapper/src/argconv.rs +++ b/scylla-rust-wrapper/src/argconv.rs @@ -78,18 +78,23 @@ pub(crate) use make_c_str; /// the memory associated with the pointer using corresponding driver's API function. pub trait BoxFFI { fn into_ptr(self: Box) -> *mut Self { + #[allow(clippy::disallowed_methods)] Box::into_raw(self) } unsafe fn from_ptr(ptr: *mut Self) -> Box { + #[allow(clippy::disallowed_methods)] Box::from_raw(ptr) } unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { + #[allow(clippy::disallowed_methods)] ptr.as_ref() } unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + #[allow(clippy::disallowed_methods)] ptr.as_ref().unwrap() } unsafe fn as_mut_ref<'a>(ptr: *mut Self) -> &'a mut Self { + #[allow(clippy::disallowed_methods)] ptr.as_mut().unwrap() } unsafe fn free(ptr: *mut Self) { @@ -105,22 +110,29 @@ pub trait BoxFFI { /// with the pointer using corresponding driver's API function. pub trait ArcFFI { fn as_ptr(self: &Arc) -> *const Self { + #[allow(clippy::disallowed_methods)] Arc::as_ptr(self) } fn into_ptr(self: Arc) -> *const Self { + #[allow(clippy::disallowed_methods)] Arc::into_raw(self) } unsafe fn from_ptr(ptr: *const Self) -> Arc { + #[allow(clippy::disallowed_methods)] Arc::from_raw(ptr) } unsafe fn cloned_from_ptr(ptr: *const Self) -> Arc { + #[allow(clippy::disallowed_methods)] Arc::increment_strong_count(ptr); + #[allow(clippy::disallowed_methods)] Arc::from_raw(ptr) } unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { + #[allow(clippy::disallowed_methods)] ptr.as_ref() } unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + #[allow(clippy::disallowed_methods)] ptr.as_ref().unwrap() } unsafe fn free(ptr: *const Self) { @@ -141,6 +153,7 @@ pub trait RefFFI { self as *const Self } unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { + #[allow(clippy::disallowed_methods)] ptr.as_ref().unwrap() } } diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 4fa84c3a..a3c46f52 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -489,6 +489,7 @@ mod tests { // the future, and execute its callback #[test] #[ntest::timeout(600)] + #[allow(clippy::disallowed_methods)] fn test_cass_future_callback() { unsafe { const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";