diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index f75540ae..66294e74 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_batch_types.rs")); -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct UDTDataType { // Vec to preserve the order of types pub field_types: Vec<(String, Arc)>, @@ -131,14 +131,14 @@ impl Default for UDTDataType { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum MapDataType { Untyped, Key(Arc), KeyAndValue(Arc, Arc), } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum CassDataType { Value(CassValueType), UDT(UDTDataType), diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 33fbcba6..aa29186d 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -12,6 +12,11 @@ use scylla::prepared_statement::PreparedStatement; pub struct CassPrepared { // Data types of columns from PreparedMetadata. pub variable_col_data_types: Vec>, + // Data types of columns from ResultMetadata. + // + // Arc -> to share each data type with other structs such as `CassValue` + // Arc> -> to share the whole vector with `CassResultData`. + pub result_col_data_types: Arc>>, pub statement: PreparedStatement, } @@ -23,8 +28,17 @@ impl CassPrepared { .map(|col_spec| Arc::new(get_column_type(&col_spec.typ))) .collect(); + let result_col_data_types: Arc>> = Arc::new( + statement + .get_result_set_col_specs() + .iter() + .map(|col_spec| Arc::new(get_column_type(&col_spec.typ))) + .collect(), + ); + Self { variable_col_data_types, + result_col_data_types, statement, } } diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 4a98b991..d77b6b94 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1,6 +1,8 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType, MapDataType}; +use crate::cass_types::{ + cass_data_type_type, get_column_type, CassDataType, CassValueType, MapDataType, +}; use crate::inet::CassInet; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, @@ -22,9 +24,39 @@ pub struct CassResult { pub struct CassResultData { pub paging_state_response: PagingStateResponse, pub col_specs: Vec, + pub col_data_types: Arc>>, pub tracing_id: Option, } +impl CassResultData { + pub fn from_result_payload( + paging_state_response: PagingStateResponse, + col_specs: Vec, + maybe_col_data_types: Option>>>, + tracing_id: Option, + ) -> CassResultData { + // `maybe_col_data_types` is: + // - Some(_) for prepared statements executions + // - None for unprepared (simple) queries executions + let col_data_types = maybe_col_data_types.unwrap_or_else(|| { + // This allocation is unfortunately necessary, because of the type of CassResultData::col_data_types. + Arc::new( + col_specs + .iter() + .map(|col_spec| Arc::new(get_column_type(&col_spec.typ))) + .collect(), + ) + }); + + CassResultData { + paging_state_response, + col_specs, + col_data_types, + tracing_id, + } + } +} + /// The lifetime of CassRow is bound to CassResult. /// It will be freed, when CassResult is freed.(see #[cass_result_free]) pub struct CassRow { @@ -905,6 +937,36 @@ pub unsafe extern "C" fn cass_result_column_name( CassError::CASS_OK } +#[no_mangle] +pub unsafe extern "C" fn cass_result_column_type( + result: *const CassResult, + index: size_t, +) -> CassValueType { + let data_type_ptr = cass_result_column_data_type(result, index); + if data_type_ptr.is_null() { + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + } + cass_data_type_type(data_type_ptr) +} + +#[no_mangle] +pub unsafe extern "C" fn cass_result_column_data_type( + result: *const CassResult, + index: size_t, +) -> *const CassDataType { + let result_from_raw: &CassResult = ptr_to_ref(result); + let index_usize: usize = index + .try_into() + .expect("Provided index is out of bounds. Max possible value is usize::MAX"); + + result_from_raw + .metadata + .col_data_types + .get(index_usize) + .map(Arc::as_ptr) + .unwrap_or(std::ptr::null()) +} + #[no_mangle] pub unsafe extern "C" fn cass_value_type(value: *const CassValue) -> CassValueType { let value_from_raw = ptr_to_ref(value); @@ -1283,11 +1345,12 @@ pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> *const CassRow { let result = ptr_to_ref(result_raw); - if result.rows.is_some() || result.rows.as_ref().unwrap().is_empty() { - return result.rows.as_ref().unwrap().first().unwrap(); - } - - std::ptr::null() + result + .rows + .as_ref() + .and_then(|rows| rows.first()) + .map(|row| row as *const CassRow) + .unwrap_or(std::ptr::null()) } #[no_mangle] @@ -1322,6 +1385,200 @@ pub unsafe extern "C" fn cass_result_paging_state_token( CassError::CASS_OK } +#[cfg(test)] +mod tests { + use std::{ffi::c_char, ptr::addr_of_mut, sync::Arc}; + + use scylla::{ + frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row, TableSpec}, + transport::PagingStateResponse, + }; + + use crate::{ + cass_error::CassError, + cass_types::{CassDataType, CassValueType}, + query_result::{ + cass_result_column_data_type, cass_result_column_name, cass_result_first_row, + ptr_to_cstr_n, ptr_to_ref, size_t, + }, + session::create_cass_rows_from_rows, + }; + + use super::{cass_result_column_count, cass_result_column_type, CassResult, CassResultData}; + + fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec::borrowed("ks", "tbl"), + name: name.to_owned(), + typ, + } + } + + const FIRST_COLUMN_NAME: &str = "bigint_col"; + const SECOND_COLUMN_NAME: &str = "varint_col"; + const THIRD_COLUMN_NAME: &str = "list_double_col"; + fn create_cass_rows_result() -> CassResult { + let metadata = Arc::new(CassResultData::from_result_payload( + PagingStateResponse::NoMorePages, + vec![ + col_spec(FIRST_COLUMN_NAME, ColumnType::BigInt), + col_spec(SECOND_COLUMN_NAME, ColumnType::Varint), + col_spec( + THIRD_COLUMN_NAME, + ColumnType::List(Box::new(ColumnType::Double)), + ), + ], + None, + None, + )); + + let rows = create_cass_rows_from_rows( + vec![Row { + columns: vec![ + Some(CqlValue::BigInt(42)), + None, + Some(CqlValue::List(vec![ + CqlValue::Float(0.5), + CqlValue::Float(42.42), + CqlValue::Float(9999.9999), + ])), + ], + }], + &metadata, + ); + + CassResult { + rows: Some(rows), + metadata, + } + } + + unsafe fn cass_result_column_name_rust_str( + result_ptr: *const CassResult, + column_index: u64, + ) -> Option<&'static str> { + let mut name_ptr: *const c_char = std::ptr::null(); + let mut name_length: size_t = 0; + let cass_err = cass_result_column_name( + result_ptr, + column_index, + addr_of_mut!(name_ptr), + addr_of_mut!(name_length), + ); + assert_eq!(CassError::CASS_OK, cass_err); + ptr_to_cstr_n(name_ptr, name_length) + } + + #[test] + fn rows_cass_result_api_test() { + let result = create_cass_rows_result(); + + unsafe { + let result_ptr = std::ptr::addr_of!(result); + + // cass_result_column_count test + { + let column_count = cass_result_column_count(result_ptr); + assert_eq!(3, column_count); + } + + // cass_result_column_name test + { + let first_column_name = cass_result_column_name_rust_str(result_ptr, 0).unwrap(); + assert_eq!(FIRST_COLUMN_NAME, first_column_name); + let second_column_name = cass_result_column_name_rust_str(result_ptr, 1).unwrap(); + assert_eq!(SECOND_COLUMN_NAME, second_column_name); + let third_column_name = cass_result_column_name_rust_str(result_ptr, 2).unwrap(); + assert_eq!(THIRD_COLUMN_NAME, third_column_name); + } + + // cass_result_column_type test + { + let first_col_type = cass_result_column_type(result_ptr, 0); + assert_eq!(CassValueType::CASS_VALUE_TYPE_BIGINT, first_col_type); + let second_col_type = cass_result_column_type(result_ptr, 1); + assert_eq!(CassValueType::CASS_VALUE_TYPE_VARINT, second_col_type); + let third_col_type = cass_result_column_type(result_ptr, 2); + assert_eq!(CassValueType::CASS_VALUE_TYPE_LIST, third_col_type); + let out_of_bound_col_type = cass_result_column_type(result_ptr, 555); + assert_eq!( + CassValueType::CASS_VALUE_TYPE_UNKNOWN, + out_of_bound_col_type + ); + } + + // cass_result_column_data_type test + { + let first_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 0)); + assert_eq!( + &CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BIGINT), + first_col_data_type + ); + let second_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 1)); + assert_eq!( + &CassDataType::Value(CassValueType::CASS_VALUE_TYPE_VARINT), + second_col_data_type + ); + let third_col_data_type = ptr_to_ref(cass_result_column_data_type(result_ptr, 2)); + assert_eq!( + &CassDataType::List { + typ: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_DOUBLE + ))), + frozen: false + }, + third_col_data_type + ); + let out_of_bound_col_data_type = cass_result_column_data_type(result_ptr, 555); + assert!(out_of_bound_col_data_type.is_null()); + } + } + } + + fn create_non_rows_cass_result() -> CassResult { + let metadata = Arc::new(CassResultData::from_result_payload( + PagingStateResponse::NoMorePages, + vec![], + None, + None, + )); + CassResult { + rows: None, + metadata, + } + } + + #[test] + fn non_rows_cass_result_api_test() { + let result = create_non_rows_cass_result(); + + // Check that API functions do not panic when rows are empty - e.g. for INSERT queries. + unsafe { + let result_ptr = std::ptr::addr_of!(result); + + assert_eq!(0, cass_result_column_count(result_ptr)); + assert_eq!( + CassValueType::CASS_VALUE_TYPE_UNKNOWN, + cass_result_column_type(result_ptr, 0) + ); + assert!(cass_result_column_data_type(result_ptr, 0).is_null()); + assert!(cass_result_first_row(result_ptr).is_null()); + + { + let mut name_ptr: *const c_char = std::ptr::null(); + let mut name_length: size_t = 0; + let cass_err = cass_result_column_name( + result_ptr, + 0, + addr_of_mut!(name_ptr), + addr_of_mut!(name_length), + ); + assert_eq!(CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, cass_err); + } + } + } +} + // CassResult functions: /* extern "C" { diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 2f5f17e8..fe5f419f 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; -use crate::cass_types::{get_column_type, CassDataType, MapDataType, UDTDataType}; +use crate::cass_types::{CassDataType, MapDataType, UDTDataType}; use crate::cluster::build_session_builder; use crate::cluster::CassCluster; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; @@ -214,11 +214,12 @@ pub unsafe extern "C" fn cass_session_execute_batch( match query_res { Ok(_result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult { rows: None, - metadata: Arc::new(CassResultData { - paging_state_response: PagingStateResponse::NoMorePages, - col_specs: vec![], - tracing_id: None, - }), + metadata: Arc::new(CassResultData::from_result_payload( + PagingStateResponse::NoMorePages, + vec![], + None, + None, + )), }))), Err(err) => Ok(CassResultValue::QueryError(Arc::new(err))), } @@ -285,41 +286,80 @@ pub unsafe extern "C" fn cass_session_execute( .set_execution_profile_handle(handle), } - let query_res: Result<(QueryResult, PagingStateResponse), QueryError> = match statement { + // Creating a type alias here to fix clippy lints. + // I want this type to be explicit, so future developers can understand + // what's going on here (and why we include some weird Option of data types). + type QueryRes = Result< + ( + QueryResult, + PagingStateResponse, + // We unfortunately have to retrieve the metadata here. + // Since `query.query` is consumed, we cannot match the statement + // after execution, to retrieve the cached metadata in case + // of prepared statements. + Option>>>, + ), + QueryError, + >; + let query_res: QueryRes = match statement { Statement::Simple(query) => { + // We don't store result metadata for Queries - return None. + let maybe_result_col_data_types = None; + if paging_enabled { session .query_single_page(query.query, bound_values, paging_state) .await + .map(|(qr, psr)| (qr, psr, maybe_result_col_data_types)) } else { session .query_unpaged(query.query, bound_values) .await - .map(|result| (result, PagingStateResponse::NoMorePages)) + .map(|result| { + ( + result, + PagingStateResponse::NoMorePages, + maybe_result_col_data_types, + ) + }) } } Statement::Prepared(prepared) => { + // Clone vector of the Arc, so we don't do additional allocations when constructing + // CassDataTypes in `CassResultData::from_result_payload`. + let maybe_result_col_data_types = Some(prepared.result_col_data_types.clone()); + if paging_enabled { session .execute_single_page(&prepared.statement, bound_values, paging_state) .await + .map(|(qr, psr)| (qr, psr, maybe_result_col_data_types)) } else { session .execute_unpaged(&prepared.statement, bound_values) .await - .map(|result| (result, PagingStateResponse::NoMorePages)) + .map(|result| { + ( + result, + PagingStateResponse::NoMorePages, + maybe_result_col_data_types, + ) + }) } } }; match query_res { - Ok((result, paging_state_response)) => { - let metadata = Arc::new(CassResultData { + Ok((result, paging_state_response, maybe_col_data_types)) => { + let metadata = Arc::new(CassResultData::from_result_payload( paging_state_response, - col_specs: result.col_specs().to_vec(), - tracing_id: result.tracing_id, - }); - let cass_rows = create_cass_rows_from_rows(result.rows, &metadata); + result.col_specs().to_vec(), + maybe_col_data_types, + result.tracing_id, + )); + let cass_rows = result + .rows + .map(|rows| create_cass_rows_from_rows(rows, &metadata)); let cass_result = Arc::new(CassResult { rows: cass_rows, metadata, @@ -339,28 +379,24 @@ pub unsafe extern "C" fn cass_session_execute( } } -fn create_cass_rows_from_rows( - rows: Option>, +pub(crate) fn create_cass_rows_from_rows( + rows: Vec, metadata: &Arc, -) -> Option> { - let rows = rows?; - let cass_rows = rows - .into_iter() +) -> Vec { + rows.into_iter() .map(|r| CassRow { columns: create_cass_row_columns(r, metadata), result_metadata: metadata.clone(), }) - .collect(); - - Some(cass_rows) + .collect() } fn create_cass_row_columns(row: Row, metadata: &Arc) -> Vec { row.columns .into_iter() - .zip(metadata.col_specs.iter()) - .map(|(val, col)| { - let column_type = Arc::new(get_column_type(&col.typ)); + .zip(metadata.col_data_types.iter()) + .map(|(val, col_data_type)| { + let column_type = Arc::clone(col_data_type); CassValue { value: val.map(|col_val| get_column_value(col_val, &column_type)), value_type: column_type,