diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index a063335d..37fed39e 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -510,9 +510,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.11.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1027,9 +1027,8 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "scylla" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b20b46cf4ea921ba41121ba9ddf933185cd830cbe2c4fa6272a6e274a6b7368d" +version = "0.14.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.14.0#0341198937871ea135fdca3ecfcea9a792ee9b18" dependencies = [ "arc-swap", "async-trait", @@ -1085,9 +1084,8 @@ dependencies = [ [[package]] name = "scylla-cql" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27ea3cd3ff5bf9d7db7a6d65c54cecf52f7c40b8e3e32c8c2d6da84d23776ea4" +version = "0.3.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.14.0#0341198937871ea135fdca3ecfcea9a792ee9b18" dependencies = [ "async-trait", "byteorder", @@ -1102,9 +1100,8 @@ dependencies = [ [[package]] name = "scylla-macros" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e50f3e2aec7ea9f495e029fb783eb34c64d26a8f2055e1d6b43d00e04d2fbda6" +version = "0.6.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.14.0#0341198937871ea135fdca3ecfcea9a792ee9b18" dependencies = [ "darling", "proc-macro2", @@ -1114,9 +1111,8 @@ dependencies = [ [[package]] name = "scylla-proxy" -version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c5b3907e01611b2c514fc7a5563be863814ed9f0034e7080a86515232c20433" +version = "0.0.3" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.14.0#0341198937871ea135fdca3ecfcea9a792ee9b18" dependencies = [ "bigdecimal", "byteorder", diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index b072d706..bc171a80 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -10,7 +10,9 @@ categories = ["database"] license = "MIT OR Apache-2.0" [dependencies] -scylla = { version = "0.13.1", features = ["ssl"] } +scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0", features = [ + "ssl", +] } tokio = { version = "1.27.0", features = ["full"] } lazy_static = "1.4.0" uuid = "1.1.2" @@ -29,11 +31,11 @@ bindgen = "0.65" chrono = "0.4.20" [dev-dependencies] +scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0" } + assert_matches = "1.5.0" ntest = "0.9.3" rusty-fork = "0.3.0" -scylla-proxy = { version = "0.0.4" } - [lib] name = "scylla_cpp_driver" crate-type = ["cdylib", "staticlib"] diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index 5ed8151b..3576e7d4 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -15,6 +15,7 @@ impl From<&QueryError> for CassError { QueryError::UnableToAllocStreamId => CassError::CASS_ERROR_LIB_NO_STREAMS, QueryError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT, QueryError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION, + QueryError::CqlResponseParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE, } } } @@ -83,6 +84,9 @@ impl From<&NewSessionError> for CassError { NewSessionError::UnableToAllocStreamId => CassError::CASS_ERROR_LAST_ENTRY, NewSessionError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT, NewSessionError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION, + NewSessionError::CqlResponseParseError(_) => { + CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE + } } } } diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 1c4721fb..5094d652 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -1,4 +1,4 @@ -use scylla::frame::value::MaybeUnset::Unset; +use scylla::{frame::value::MaybeUnset::Unset, transport::PagingState}; use std::sync::Arc; use crate::{ @@ -28,7 +28,9 @@ pub unsafe extern "C" fn cass_prepared_bind( Box::into_raw(Box::new(CassStatement { statement, bound_values: vec![Unset; bound_values_size], - paging_state: None, + paging_state: PagingState::start(), + // Cpp driver disables paging by default. + paging_enabled: false, request_timeout_ms: None, exec_profile: None, })) diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 716b8101..894cbd85 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -8,7 +8,7 @@ use crate::metadata::{ use crate::types::*; use crate::uuid::CassUuid; use scylla::frame::response::result::{ColumnSpec, CqlValue}; -use scylla::Bytes; +use scylla::transport::PagingStateResponse; use std::convert::TryInto; use std::os::raw::c_char; use std::sync::Arc; @@ -20,7 +20,7 @@ pub struct CassResult { } pub struct CassResultData { - pub paging_state: Option, + pub paging_state_response: PagingStateResponse, pub col_specs: Vec, pub tracing_id: Option, } @@ -815,7 +815,7 @@ pub unsafe extern "C" fn cass_result_free(result_raw: *const CassResult) { #[no_mangle] pub unsafe extern "C" fn cass_result_has_more_pages(result: *const CassResult) -> cass_bool_t { let result = ptr_to_ref(result); - result.metadata.paging_state.is_some() as cass_bool_t + (!result.metadata.paging_state_response.finished()) as cass_bool_t } #[no_mangle] @@ -1298,12 +1298,18 @@ pub unsafe extern "C" fn cass_result_paging_state_token( let result_from_raw = ptr_to_ref(result); - match &result_from_raw.metadata.paging_state { - Some(result_paging_state) => { - *paging_state_size = result_paging_state.len() as u64; - *paging_state = result_paging_state.as_ptr() as *const c_char; - } - None => { + match &result_from_raw.metadata.paging_state_response { + PagingStateResponse::HasMorePages { state } => match state.as_bytes_slice() { + Some(result_paging_state) => { + *paging_state_size = result_paging_state.len() as u64; + *paging_state = result_paging_state.as_ptr() as *const c_char; + } + None => { + *paging_state_size = 0; + *paging_state = std::ptr::null(); + } + }, + PagingStateResponse::NoMorePages => { *paging_state_size = 0; *paging_state = std::ptr::null(); } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 8ffef2af..2060308f 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -18,6 +18,7 @@ use scylla::frame::types::Consistency; use scylla::query::Query; use scylla::transport::errors::QueryError; use scylla::transport::execution_profile::ExecutionProfileHandle; +use scylla::transport::PagingStateResponse; use scylla::{QueryResult, Session, SessionBuilder}; use std::collections::HashMap; use std::future::Future; @@ -205,7 +206,7 @@ pub unsafe extern "C" fn cass_session_execute_batch( Ok(_result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult { rows: None, metadata: Arc::new(CassResultData { - paging_state: None, + paging_state_response: PagingStateResponse::NoMorePages, col_specs: vec![], tracing_id: None, }), @@ -244,6 +245,7 @@ pub unsafe extern "C" fn cass_session_execute( // DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault. let statement_opt = ptr_to_ref(statement_raw); let paging_state = statement_opt.paging_state.clone(); + let paging_enabled = statement_opt.paging_enabled; let bound_values = statement_opt.bound_values.clone(); let request_timeout_ms = statement_opt.request_timeout_ms; @@ -274,24 +276,38 @@ pub unsafe extern "C" fn cass_session_execute( } } - let query_res: Result = match statement { + let query_res: Result<(QueryResult, PagingStateResponse), QueryError> = match statement { Statement::Simple(query) => { - session - .query_paged(query.query, bound_values, paging_state) - .await + if paging_enabled { + session + .query_single_page(query.query, bound_values, paging_state) + .await + } else { + session + .query_unpaged(query.query, bound_values) + .await + .map(|result| (result, PagingStateResponse::NoMorePages)) + } } Statement::Prepared(prepared) => { - session - .execute_paged(&prepared, bound_values, paging_state) - .await + if paging_enabled { + session + .execute_single_page(&prepared, bound_values, paging_state) + .await + } else { + session + .execute_unpaged(&prepared, bound_values) + .await + .map(|result| (result, PagingStateResponse::NoMorePages)) + } } }; match query_res { - Ok(result) => { + Ok((result, paging_state_response)) => { let metadata = Arc::new(CassResultData { - paging_state: result.paging_state, - col_specs: result.col_specs, + paging_state_response, + col_specs: result.col_specs().to_vec(), tracing_id: result.tracing_id, }); let cass_rows = create_cass_rows_from_rows(result.rows, &metadata); @@ -516,7 +532,6 @@ pub unsafe extern "C" fn cass_session_prepare_n( .map_err(|err| (CassError::from(&err), err.msg()))?; // Set Cpp Driver default configuration for queries: - prepared.disable_paging(); prepared.set_consistency(Consistency::One); Ok(CassResultValue::Prepared(Arc::new(prepared))) diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index ad65b4e0..84d2375e 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -11,7 +11,7 @@ use scylla::frame::value::MaybeUnset::{Set, Unset}; use scylla::query::Query; use scylla::statement::prepared_statement::PreparedStatement; use scylla::statement::SerialConsistency; -use scylla::{BufMut, Bytes, BytesMut}; +use scylla::transport::{PagingState, PagingStateResponse}; use std::collections::HashMap; use std::convert::TryInto; use std::os::raw::{c_char, c_int}; @@ -36,7 +36,8 @@ pub struct SimpleQuery { pub struct CassStatement { pub statement: Statement, pub bound_values: Vec>>, - pub paging_state: Option, + pub paging_state: PagingState, + pub paging_enabled: bool, pub request_timeout_ms: Option, pub(crate) exec_profile: Option, @@ -145,10 +146,7 @@ pub unsafe extern "C" fn cass_statement_new_n( None => return std::ptr::null_mut(), }; - let mut query = Query::new(query_str.to_string()); - - // Set Cpp Driver default configuration for queries: - query.disable_paging(); + let query = Query::new(query_str.to_string()); let simple_query = SimpleQuery { query, @@ -158,7 +156,9 @@ pub unsafe extern "C" fn cass_statement_new_n( Box::into_raw(Box::new(CassStatement { statement: Statement::Simple(simple_query), bound_values: vec![Unset; parameter_count as usize], - paging_state: None, + paging_state: PagingState::start(), + // Cpp driver disables paging by default. + paging_enabled: false, request_timeout_ms: None, exec_profile: None, })) @@ -191,21 +191,15 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( statement_raw: *mut CassStatement, page_size: c_int, ) -> CassError { - // TODO: validate page_size - match &mut ptr_to_ref_mut(statement_raw).statement { - Statement::Simple(inner) => { - if page_size == -1 { - inner.query.disable_paging() - } else { - inner.query.set_page_size(page_size) - } - } - Statement::Prepared(inner) => { - if page_size == -1 { - Arc::make_mut(inner).disable_paging() - } else { - Arc::make_mut(inner).set_page_size(page_size) - } + let statement = ptr_to_ref_mut(statement_raw); + if page_size <= 0 { + // Cpp driver sets the page size flag only for positive page size provided by user. + statement.paging_enabled = false; + } else { + statement.paging_enabled = true; + match &mut statement.statement { + Statement::Simple(inner) => inner.query.set_page_size(page_size), + Statement::Prepared(inner) => Arc::make_mut(inner).set_page_size(page_size), } } @@ -220,9 +214,10 @@ pub unsafe extern "C" fn cass_statement_set_paging_state( let statement = ptr_to_ref_mut(statement); let result = ptr_to_ref(result); - statement - .paging_state - .clone_from(&result.metadata.paging_state); + match &result.metadata.paging_state_response { + PagingStateResponse::HasMorePages { state } => statement.paging_state.clone_from(state), + PagingStateResponse::NoMorePages => statement.paging_state = PagingState::start(), + } CassError::CASS_OK } @@ -235,18 +230,13 @@ pub unsafe extern "C" fn cass_statement_set_paging_state_token( let statement_from_raw = ptr_to_ref_mut(statement); if paging_state.is_null() { - statement_from_raw.paging_state = None; + statement_from_raw.paging_state = PagingState::start(); return CassError::CASS_ERROR_LIB_NULL_VALUE; } let paging_state_usize: usize = paging_state_size.try_into().unwrap(); - let mut b = BytesMut::with_capacity(paging_state_usize); - let paging_state_bytes = slice::from_raw_parts(paging_state, paging_state_usize); - for byte in paging_state_bytes { - b.put_i8(*byte); - } - statement_from_raw.paging_state = Some(b.freeze()); - + let paging_state_bytes = slice::from_raw_parts(paging_state as *const u8, paging_state_usize); + statement_from_raw.paging_state = PagingState::new_from_raw_bytes(paging_state_bytes); CassError::CASS_OK } diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 397587de..ffd02feb 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -7,7 +7,7 @@ use scylla::{ }, serialize::{ value::{ - BuiltinSerializationErrorKind, MapSerializationErrorKind, SerializeCql, + BuiltinSerializationErrorKind, MapSerializationErrorKind, SerializeValue, SetOrListSerializationErrorKind, TupleSerializationErrorKind, UdtSerializationErrorKind, }, @@ -60,7 +60,7 @@ pub enum CassCqlValue { // TODO: custom (?), duration and decimal } -impl SerializeCql for CassCqlValue { +impl SerializeValue for CassCqlValue { fn serialize<'b>( &self, _typ: &ColumnType, @@ -89,44 +89,44 @@ impl CassCqlValue { // will never fail. Thanks to that, we do not have to reimplement low-level serialization // for each type. CassCqlValue::TinyInt(v) => { - ::serialize(v, &ColumnType::TinyInt, writer) + ::serialize(v, &ColumnType::TinyInt, writer) } CassCqlValue::SmallInt(v) => { - ::serialize(v, &ColumnType::SmallInt, writer) + ::serialize(v, &ColumnType::SmallInt, writer) } - CassCqlValue::Int(v) => ::serialize(v, &ColumnType::Int, writer), + CassCqlValue::Int(v) => ::serialize(v, &ColumnType::Int, writer), CassCqlValue::BigInt(v) => { - ::serialize(v, &ColumnType::BigInt, writer) + ::serialize(v, &ColumnType::BigInt, writer) } CassCqlValue::Float(v) => { - ::serialize(v, &ColumnType::Float, writer) + ::serialize(v, &ColumnType::Float, writer) } CassCqlValue::Double(v) => { - ::serialize(v, &ColumnType::Double, writer) + ::serialize(v, &ColumnType::Double, writer) } CassCqlValue::Boolean(v) => { - ::serialize(v, &ColumnType::Boolean, writer) + ::serialize(v, &ColumnType::Boolean, writer) } CassCqlValue::Text(v) => { - ::serialize(v, &ColumnType::Text, writer) + ::serialize(v, &ColumnType::Text, writer) } CassCqlValue::Blob(v) => { - as SerializeCql>::serialize(v, &ColumnType::Blob, writer) + as SerializeValue>::serialize(v, &ColumnType::Blob, writer) } CassCqlValue::Uuid(v) => { - ::serialize(v, &ColumnType::Uuid, writer) + ::serialize(v, &ColumnType::Uuid, writer) } CassCqlValue::Date(v) => { - ::serialize(v, &ColumnType::Date, writer) + ::serialize(v, &ColumnType::Date, writer) } CassCqlValue::Inet(v) => { - ::serialize(v, &ColumnType::Inet, writer) + ::serialize(v, &ColumnType::Inet, writer) } CassCqlValue::Duration(v) => { - ::serialize(v, &ColumnType::Duration, writer) + ::serialize(v, &ColumnType::Duration, writer) } CassCqlValue::Decimal(v) => { - ::serialize(v, &ColumnType::Decimal, writer) + ::serialize(v, &ColumnType::Decimal, writer) } CassCqlValue::Tuple(fields) => serialize_tuple_like(fields.iter(), writer), CassCqlValue::List(l) => serialize_sequence(l.len(), l.iter(), writer),