diff --git a/Cargo.toml b/Cargo.toml index dbe2fce..afd412c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,3 @@ indicatif = "0.18" env_logger = "0.11" insta = "1.43.2" rustyline = { version = "14.0", features = ["derive"] } - - diff --git a/src/lib.rs b/src/lib.rs index da2a616..cbc3787 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,8 @@ mod variant_object_construct; mod variant_object_delete; mod variant_object_insert; mod variant_pretty; +mod variant_schema; +mod variant_schema_agg; mod variant_to_json; pub use cast_to_variant::*; @@ -26,4 +28,6 @@ pub use variant_object_construct::*; pub use variant_object_delete::*; pub use variant_object_insert::*; pub use variant_pretty::*; +pub use variant_schema::*; +pub use variant_schema_agg::*; pub use variant_to_json::*; diff --git a/src/variant_schema.rs b/src/variant_schema.rs new file mode 100644 index 0000000..e5b9463 --- /dev/null +++ b/src/variant_schema.rs @@ -0,0 +1,501 @@ +use arrow::array::{ArrayRef, StringViewArray}; +use arrow_schema::{DataType, TimeUnit}; +use datafusion::{ + common::exec_err, + error::{DataFusionError, Result}, + logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility}, + scalar::ScalarValue, +}; +use parquet_variant::Variant; +use parquet_variant_compute::VariantArray; +use std::collections::BTreeMap; +use std::collections::btree_map::Entry; +use std::sync::Arc; + +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct VariantSchemaUDF { + signature: Signature, +} + +impl Default for VariantSchemaUDF { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } + } +} + +/// Infers a schema description for one VARIANT value. 
+/// +/// The inferred schema can be one of four logical forms: +/// - Primitive: a concrete SQL / Arrow data type +/// - Array: `ARRAY`, where `inner` is merged across elements in that array value +/// - Object: `OBJECT`, merged recursively per field +/// - Variant: fallback when no common inner schema can be determined +/// +/// Execution semantics: +/// - Scalar input: infer one schema string for that value. +/// - Columnar input: infer one schema string per row (vectorized row-wise behavior). +/// - This function does not merge schemas across rows. For cross-row/group merge use +/// `variant_schema_agg`. +/// +/// Merge rules (within one VARIANT value only): +/// - If outer (or inner) kinds differ, the result is `VARIANT` +/// - Primitive types are merged using widening / least-common-type rules +/// - Arrays merge by merging their element schemas +/// - Objects merge field-by-field; missing fields are allowed +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum VariantSchema { + Primitive(DataType), + Array(Box), + Object(BTreeMap), + Variant, +} + +impl VariantSchema { + pub fn to_state_bytes(&self) -> Vec { + let mut out = Vec::new(); + encode_variant_schema(self, &mut out); + out + } + + pub fn from_state_bytes(bytes: &[u8]) -> Result { + let mut offset = 0usize; + let decoded = decode_variant_schema(bytes, &mut offset)?; + if offset != bytes.len() { + return exec_err!("invalid variant_schema_agg state: trailing bytes"); + } + Ok(decoded) + } +} + +fn encode_len_prefixed_bytes(out: &mut Vec, bytes: &[u8]) { + out.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + out.extend_from_slice(bytes); +} + +fn read_u8(input: &[u8], offset: &mut usize) -> Result { + let Some(v) = input.get(*offset) else { + return exec_err!("invalid variant_schema_agg state: missing tag"); + }; + *offset += 1; + Ok(*v) +} + +fn read_u32(input: &[u8], offset: &mut usize) -> Result { + let Some(raw) = input.get(*offset..(*offset + 4)) else { + return exec_err!("invalid 
variant_schema_agg state: missing u32"); + }; + *offset += 4; + Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])) +} + +fn read_len_prefixed_bytes<'a>(input: &'a [u8], offset: &mut usize) -> Result<&'a [u8]> { + let len = read_u32(input, offset)? as usize; + let Some(raw) = input.get(*offset..(*offset + len)) else { + return exec_err!("invalid variant_schema_agg state: truncated payload"); + }; + *offset += len; + Ok(raw) +} + +fn encode_variant_schema(schema: &VariantSchema, out: &mut Vec) { + match schema { + VariantSchema::Primitive(dtype) => { + out.push(0); + encode_len_prefixed_bytes(out, dtype.to_string().as_bytes()); + } + VariantSchema::Array(inner) => { + out.push(1); + encode_variant_schema(inner, out); + } + VariantSchema::Object(fields) => { + out.push(2); + out.extend_from_slice(&(fields.len() as u32).to_le_bytes()); + for (key, value) in fields { + encode_len_prefixed_bytes(out, key.as_bytes()); + encode_variant_schema(value, out); + } + } + VariantSchema::Variant => out.push(3), + } +} + +fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result { + match read_u8(input, offset)? { + 0 => { + let raw = read_len_prefixed_bytes(input, offset)?; + let dtype_str = match std::str::from_utf8(raw) { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema_agg state: {e}"), + }; + let dtype = match dtype_str.parse::() { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema_agg datatype state: {e}"), + }; + Ok(VariantSchema::Primitive(dtype)) + } + 1 => Ok(VariantSchema::Array(Box::new(decode_variant_schema( + input, offset, + )?))), + 2 => { + let count = read_u32(input, offset)? 
as usize; + let mut fields = BTreeMap::new(); + for _ in 0..count { + let key_raw = read_len_prefixed_bytes(input, offset)?; + let key = match std::str::from_utf8(key_raw) { + Ok(v) => v.to_string(), + Err(e) => return exec_err!("invalid variant_schema_agg field key: {e}"), + }; + let value = decode_variant_schema(input, offset)?; + fields.insert(key, value); + } + Ok(VariantSchema::Object(fields)) + } + 3 => Ok(VariantSchema::Variant), + tag => exec_err!("invalid variant_schema_agg state tag: {tag}"), + } +} + +/// This function extracts the schema from a single Variant scalar +pub fn schema_from_variant(v: &Variant) -> VariantSchema { + match v { + Variant::Object(obj) => { + let fields = obj + .iter() + .map(|(k, v)| (k.to_string(), schema_from_variant(&v))) + .collect(); + + VariantSchema::Object(fields) + } + Variant::List(list) => { + let inner = list + .iter() + .map(|v| schema_from_variant(&v)) + .try_fold(VariantSchema::Primitive(DataType::Null), |acc, next| { + let merged = merge_variant_schema(acc, next); + if merged == VariantSchema::Variant { + Err(merged) + } else { + Ok(merged) + } + }) + .unwrap_or_else(|schema| schema); + + VariantSchema::Array(Box::new(inner)) + } + _ => VariantSchema::Primitive(primitive_from_variant(v)), + } +} + +/// This helper function is used to calculate decimal precision +/// for [primitive_from_variant] decimal Variants conversion +fn decimal_precision>(val: T) -> u8 { + let mut n = val.into(); + if n == 0 { + return 1; + } + if n < 0 { + n = -n + } + + let mut digits = 0; + while n != 0 { + digits += 1; + n /= 10; + } + digits +} + +/// This function is used to extract datatype from a primitive Variant +fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { + match v { + Variant::Null => DataType::Null, + Variant::Int8(_) => DataType::Int8, + Variant::Int16(_) => DataType::Int16, + Variant::Int32(_) => DataType::Int32, + Variant::Int64(_) => DataType::Int64, + Variant::Float(_) => DataType::Float32, + 
Variant::Double(_) => DataType::Float64, + Variant::Decimal4(d) => { + DataType::Decimal32(decimal_precision(d.integer()), d.scale() as i8) + } + Variant::Decimal8(d) => { + DataType::Decimal64(decimal_precision(d.integer()), d.scale() as i8) + } + Variant::Decimal16(d) => { + DataType::Decimal128(decimal_precision(d.integer()), d.scale() as i8) + } + Variant::BooleanTrue | Variant::BooleanFalse => DataType::Boolean, + Variant::String(_) | Variant::ShortString(_) | Variant::Uuid(_) => DataType::Utf8, + Variant::Binary(_) => DataType::Binary, + Variant::Date(_) => DataType::Date32, + Variant::Time(_) => DataType::Time64(TimeUnit::Microsecond), + Variant::TimestampMicros(_) => { + DataType::Timestamp(TimeUnit::Microsecond, Some("utc".into())) + } + Variant::TimestampNtzMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, None), + Variant::TimestampNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, Some("utc".into())), + Variant::TimestampNtzNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, None), + _ => unreachable!("Should be only applied to Primitive Variant, not Object or List"), + } +} + +/// This function is used to merge types between schemas +/// and coerce them into a common type when possible if types +/// are different +fn merge_decimal_types(p1: u8, s1: i8, p2: u8, s2: i8) -> Option { + const DECIMAL128_MAX_PRECISION: i16 = 38; + + // Decimal scale is non-negative in Arrow logical types. + if s1 < 0 || s2 < 0 { + return None; + } + + let scale = s1.max(s2); + let int_digits_1 = p1 as i16 - s1 as i16; + let int_digits_2 = p2 as i16 - s2 as i16; + let int_digits = int_digits_1.max(int_digits_2); + let precision = int_digits + scale as i16; + let precision = precision.max(1); + + // Decimal128 max precision in Arrow. 
+ if precision > DECIMAL128_MAX_PRECISION { + return None; + } + + Some(DataType::Decimal128(precision as u8, scale)) +} + +fn merge_int_and_decimal(int_min_precision: u8, p: u8, s: i8) -> Option { + merge_decimal_types(int_min_precision, 0, p, s) +} + +fn merge_primitives(a: DataType, b: DataType) -> Option { + use DataType::*; + const MIN_DECIMAL_PRECISION_FOR_INT8_INT16_INT32: u8 = 10; + const MIN_DECIMAL_PRECISION_FOR_INT64: u8 = 20; + + match (a, b) { + (x, y) if x == y => Some(x), + // numeric widening + // docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list + // For least common type resolution FLOAT is skipped to avoid loss of precision. + (Int8 | Int16 | Int32 | Int64 | Float32, Float64) + | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Some(Float64), + (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Some(Int64), + (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Some(Int32), + (Int8, Int16) | (Int16, Int8) => Some(Int16), + // Keep precision safety over float32 when mixing integral + float32. + (Int8 | Int16 | Int32 | Int64, Float32) | (Float32, Int8 | Int16 | Int32 | Int64) => { + Some(Float64) + } + (Timestamp(tu1, tz1), Timestamp(tu2, tz2)) => { + if tz1 != tz2 { + None + } else { + let merged_tu = + if matches!(tu1, TimeUnit::Nanosecond) || matches!(tu2, TimeUnit::Nanosecond) { + TimeUnit::Nanosecond + } else { + TimeUnit::Microsecond + }; + Some(Timestamp(merged_tu, tz1)) + } + } + // Databricks precedence list promotes DATE -> TIMESTAMP. + // Preserve the timestamp timezone annotation when present. 
+ (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Some(Timestamp(tu, tz)), + (Decimal32(p1, s1), Decimal32(p2, s2)) + | (Decimal32(p1, s1), Decimal64(p2, s2)) + | (Decimal32(p1, s1), Decimal128(p2, s2)) + | (Decimal64(p1, s1), Decimal32(p2, s2)) + | (Decimal64(p1, s1), Decimal64(p2, s2)) + | (Decimal64(p1, s1), Decimal128(p2, s2)) + | (Decimal128(p1, s1), Decimal32(p2, s2)) + | (Decimal128(p1, s1), Decimal64(p2, s2)) + | (Decimal128(p1, s1), Decimal128(p2, s2)) => merge_decimal_types(p1, s1, p2, s2), + (Int8, Decimal32(p, s)) + | (Int8, Decimal64(p, s)) + | (Int8, Decimal128(p, s)) + | (Decimal32(p, s), Int8) + | (Decimal64(p, s), Int8) + | (Decimal128(p, s), Int8) => { + merge_int_and_decimal(MIN_DECIMAL_PRECISION_FOR_INT8_INT16_INT32, p, s) + } + (Int16, Decimal32(p, s)) + | (Int16, Decimal64(p, s)) + | (Int16, Decimal128(p, s)) + | (Decimal32(p, s), Int16) + | (Decimal64(p, s), Int16) + | (Decimal128(p, s), Int16) => { + merge_int_and_decimal(MIN_DECIMAL_PRECISION_FOR_INT8_INT16_INT32, p, s) + } + (Int32, Decimal32(p, s)) + | (Int32, Decimal64(p, s)) + | (Int32, Decimal128(p, s)) + | (Decimal32(p, s), Int32) + | (Decimal64(p, s), Int32) + | (Decimal128(p, s), Int32) => { + merge_int_and_decimal(MIN_DECIMAL_PRECISION_FOR_INT8_INT16_INT32, p, s) + } + (Int64, Decimal32(p, s)) + | (Int64, Decimal64(p, s)) + | (Int64, Decimal128(p, s)) + | (Decimal32(p, s), Int64) + | (Decimal64(p, s), Int64) + | (Decimal128(p, s), Int64) => merge_int_and_decimal(MIN_DECIMAL_PRECISION_FOR_INT64, p, s), + // Prefer floating fallback when mixing decimals with floating point values. + (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _), Float32 | Float64) + | (Float32 | Float64, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => { + Some(Float64) + } + + _ => None, + } +} + +/// Merges two inferred Variant schemas into a common schema. +/// Returns VARIANT if no common schema can be determined. 
+pub fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { + let mut merged = a; + merge_variant_schema_from(&mut merged, &b); + merged +} + +pub fn merge_variant_schema_from(target: &mut VariantSchema, incoming: &VariantSchema) { + use VariantSchema::*; + + if matches!(target, Variant) || matches!(incoming, Variant) { + *target = Variant; + return; + } + + if matches!(incoming, Primitive(DataType::Null)) { + return; + } + + if matches!(target, Primitive(DataType::Null)) { + *target = incoming.clone(); + return; + } + + match incoming { + Primitive(p2) => { + if let Primitive(p1) = target { + let merged = merge_primitives(p1.clone(), p2.clone()) + .map(Primitive) + .unwrap_or(Variant); + *target = merged; + } else { + *target = Variant; + } + } + Array(b) => { + if let Array(a) = target { + merge_variant_schema_from(a.as_mut(), b.as_ref()); + } else { + *target = Variant; + } + } + Object(b) => { + if let Object(a) = target { + for (k, v_b) in b { + match a.entry(k.clone()) { + Entry::Occupied(mut occ) => merge_variant_schema_from(occ.get_mut(), v_b), + Entry::Vacant(vac) => { + vac.insert(v_b.clone()); + } + } + } + } else { + *target = Variant; + } + } + Variant => { + *target = Variant; + } + } +} + +pub fn merge_variant_schema_into(target: &mut VariantSchema, incoming: VariantSchema) { + merge_variant_schema_from(target, &incoming); +} + +/// Prints schema in a presentable manner +pub fn print_schema(schema: &VariantSchema) -> String { + match schema { + VariantSchema::Primitive(s) => format!("{s}"), + + VariantSchema::Variant => "VARIANT".to_string(), + + VariantSchema::Array(inner) => { + format!("ARRAY<{}>", print_schema(inner)) + } + + VariantSchema::Object(fields) => { + let parts: Vec = fields + .iter() + .map(|(k, v)| format!("{k}: {}", print_schema(v))) + .collect(); + format!("OBJECT<{}>", parts.join(", ")) + } + } +} + +/// Retrieve schema text from a VARIANT scalar or array (row-wise for arrays). 
+fn infer_variant_schema(variant: &ColumnarValue) -> Result { + match variant { + ColumnarValue::Scalar(scalar) => { + let ScalarValue::Struct(struct_array) = scalar else { + return exec_err!("Unsupported data type: {}", scalar.data_type()); + }; + + let variant_array = VariantArray::try_new(struct_array.as_ref())?; + let v = variant_array.value(0); + let schema = schema_from_variant(&v); + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + print_schema(&schema), + )))) + } + ColumnarValue::Array(array) => { + let variant_array = VariantArray::try_new(array.as_ref())?; + let out = variant_array + .iter() + .map(|v| v.map(|v| print_schema(&schema_from_variant(&v)))) + .collect::>(); + + let out: StringViewArray = out.into(); + Ok(ColumnarValue::Array(Arc::new(out) as ArrayRef)) + } + } +} + +impl ScalarUDFImpl for VariantSchemaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_schema" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) + } + + fn invoke_with_args( + &self, + args: datafusion::logical_expr::ScalarFunctionArgs, + ) -> Result { + let arg = args.args.first().ok_or_else(|| { + DataFusionError::Execution("empty argument, expected 1 argument".to_string()) + })?; + infer_variant_schema(arg) + } +} diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs new file mode 100644 index 0000000..b9bc1dc --- /dev/null +++ b/src/variant_schema_agg.rs @@ -0,0 +1,266 @@ +use arrow::array::AsArray; +use arrow_schema::{DataType, Field, FieldRef}; +use datafusion::{ + error::Result, + logical_expr::{ + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, + }, + scalar::ScalarValue, +}; +use parquet_variant_compute::VariantArray; +use std::sync::Arc; + +use crate::{ + VariantSchema, merge_variant_schema_from, 
print_schema, schema_from_variant, + shared::try_parse_binary_columnar, +}; + +/// Aggregate schema inference for VARIANT values across rows. +/// +/// This function infers per-row schemas using `schema_from_variant` and merges +/// them into a single schema per group. +/// +/// Semantics: +/// - Input: one VARIANT expression +/// - Output: one schema string per aggregate group +/// - Row filtering should be done via SQL `FILTER (WHERE ...)` +/// +/// Use `variant_schema` for row-wise (non-aggregate) inference. +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct VariantSchemaAggUDAF { + signature: Signature, +} + +impl Default for VariantSchemaAggUDAF { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for VariantSchemaAggUDAF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_schema_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Arc::new(Field::new( + format_state_name(args.name, "variant_schema"), + DataType::Binary, + true, + ))]; + + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) + } + + fn accumulator( + &self, + acc_args: datafusion::logical_expr::function::AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VariantSchemaAccumulator::new(acc_args))) + } +} + +/// Accumulator state for `variant_schema_agg`. 
+#[derive(Debug)] +pub struct VariantSchemaAccumulator { + schema: VariantSchema, +} + +impl VariantSchemaAccumulator { + fn new(_acc_args: AccumulatorArgs) -> Self { + Self { + schema: VariantSchema::Primitive(DataType::Null), + } + } +} + +impl Accumulator for VariantSchemaAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Binary(Some( + self.schema.to_state_bytes(), + ))]) + } + + fn evaluate(&mut self) -> Result { + // Return the schema as a Utf8 representation + Ok(ScalarValue::Utf8View(Some(print_schema(&self.schema)))) + } + + fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> Result<()> { + if self.schema == VariantSchema::Variant { + return Ok(()); + } + + // We're assuming the input is an array of variants + for value in values { + // Ensure we are dealing with VariantArray and extract the variant values + let variant_array = VariantArray::try_new(value.as_struct())?; + for variant in variant_array.iter().flatten() { + let new_schema = schema_from_variant(&variant); + // Merge the new schema with the current schema + merge_variant_schema_from(&mut self.schema, &new_schema); + if self.schema == VariantSchema::Variant { + return Ok(()); + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> Result<()> { + if self.schema == VariantSchema::Variant { + return Ok(()); + } + + for state in states { + for encoded_state in try_parse_binary_columnar(state)?.into_iter().flatten() { + let new_schema = VariantSchema::from_state_bytes(encoded_state)?; + merge_variant_schema_from(&mut self.schema, &new_schema); + if self.schema == VariantSchema::Variant { + return Ok(()); + } + } + } + Ok(()) + } + + fn size(&self) -> usize { + // The size is essentially the number of variants processed, if needed + 1 // This could be expanded to return a more useful size + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::ArrayRef; + use arrow_schema::{DataType, Field, Fields, Schema}; + 
use datafusion::{ + logical_expr::{Accumulator, function::AccumulatorArgs}, + physical_expr::PhysicalSortExpr, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + use parquet_variant_compute::VariantType; + + use crate::{ + shared::build_variant_array_from_json_array, variant_schema_agg::VariantSchemaAccumulator, + }; + + #[test] + fn test_merge_batch_from_state_roundtrip() { + let schema = Schema::new(vec![ + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + ]); + + let b1 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 1}))]); + let b1: ArrayRef = Arc::new(b1.into_inner()); + + let b2 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 2.5}))]); + let b2: ArrayRef = Arc::new(b2.into_inner()); + + let expr = col("b", &schema).unwrap(); + let order_bys = vec![PhysicalSortExpr::new_default(Arc::clone(&expr))]; + let exprs = vec![expr]; + let expr_fields = vec![Arc::new( + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + )]; + + let acc1_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &order_bys, + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &exprs, + expr_fields: &expr_fields, + }; + let acc2_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &order_bys, + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &exprs, + expr_fields: &expr_fields, + }; + let merged_args = AccumulatorArgs { + return_field: 
Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &order_bys, + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &exprs, + expr_fields: &expr_fields, + }; + + let mut acc1 = VariantSchemaAccumulator::new(acc1_args); + acc1.update_batch(&[Arc::clone(&b1)]).unwrap(); + let state_1 = acc1 + .state() + .unwrap() + .into_iter() + .map(|s| s.to_array().unwrap()) + .collect::>(); + + let mut acc2 = VariantSchemaAccumulator::new(acc2_args); + acc2.update_batch(&[Arc::clone(&b2)]).unwrap(); + let state_2 = acc2 + .state() + .unwrap() + .into_iter() + .map(|s| s.to_array().unwrap()) + .collect::>(); + + let mut merged = VariantSchemaAccumulator::new(merged_args); + merged.merge_batch(&state_1).unwrap(); + merged.merge_batch(&state_2).unwrap(); + + assert_eq!( + merged.evaluate().unwrap(), + ScalarValue::Utf8View(Some("OBJECT".to_string())) + ); + } +} diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 65ce5aa..02af811 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -1,9 +1,12 @@ -use datafusion::{logical_expr::ScalarUDF, prelude::*}; +use datafusion::{ + logical_expr::{AggregateUDF, ScalarUDF}, + prelude::*, +}; use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert, VariantObjectConstruct, VariantObjectDelete, - VariantObjectInsert, VariantPretty, VariantToJsonUdf, + VariantObjectInsert, VariantPretty, VariantSchemaAggUDAF, VariantSchemaUDF, VariantToJsonUdf, }; use indicatif::ProgressBar; use sqllogictest::strict_column_validator; @@ -30,7 +33,7 @@ async fn run_sqllogictests() -> Result<(), Box> { test_files.sort(); for test_file in test_files { - println!("Running test file: {:?}", test_file); + println!("Running test file: {test_file:?}"); let relative_path = test_file 
.strip_prefix(&test_files_dir) @@ -56,6 +59,8 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(VariantListInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectDelete::default())); + ctx.register_udf(ScalarUDF::new_from_impl(VariantSchemaUDF::default())); + ctx.register_udaf(AggregateUDF::new_from_impl(VariantSchemaAggUDAF::default())); let pb = ProgressBar::new(24); diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt new file mode 100644 index 0000000..4646f6e --- /dev/null +++ b/tests/test_files/variant_schema.slt @@ -0,0 +1,189 @@ +# tests the variant_schema udf +# this function takes a VARIANT expression +# and extracts each row's SQL schema + +# simple example with a scalar value +query T +SELECT variant_schema(json_to_variant('{"key": 123, "data": [4, 5]}')) +---- +OBJECT, key: Int8> + +# column input (row-wise, non-aggregate) +statement ok +CREATE TABLE t_col AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('[1, 2, 3]')); + +query T +SELECT variant_schema(column1) FROM t_col ORDER BY 1; +---- +ARRAY +OBJECT + + +# conflicting element types in array +query T +SELECT variant_schema(json_to_variant('{"data": [{"a":"a"}, 5]}')) +---- +OBJECT> + +# typed literal +query T +SELECT variant_schema(json_to_variant(123.4)) +---- +Float64 + +# explicit string primitive +query T +SELECT variant_schema(json_to_variant('"foo"')) +---- +Utf8 + +# explicit boolean primitive +query T +SELECT variant_schema(json_to_variant('true')) +---- +Boolean + +# cast_to_variant typed DATE/TIME/TIMESTAMP/DECIMAL from columns +statement ok +CREATE TABLE t_typed AS +SELECT + CAST('1990-01-01' AS DATE) AS d, + CAST('00:00:00' AS TIME) AS t, + CAST('2015-05-14 00:00:00' AS TIMESTAMP) AS ts, + CAST(123.4 AS DECIMAL(4, 1)) AS decv; + +query T +SELECT variant_schema(cast_to_variant(d)) FROM t_typed; +---- 
+Date32 + +query T +SELECT variant_schema(cast_to_variant(t)) FROM t_typed; +---- +Time64(µs) + +query T +SELECT variant_schema(cast_to_variant(ts)) FROM t_typed; +---- +Timestamp(µs) + +query T +SELECT variant_schema(cast_to_variant(decv)) FROM t_typed; +---- +Decimal128(4, 1) + +# explicit null +query T +SELECT variant_schema(json_to_variant('null')) +---- +Null + +# json null +query T +SELECT variant_schema(json_to_variant('{"a": null}')) +---- +OBJECT + +# numeric widening +query T +SELECT variant_schema(json_to_variant('[1, 2.5, 3]')) +---- +ARRAY + +# typed widening: tinyint + smallint +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST(1 AS TINYINT)), + cast_to_variant(CAST(2 AS SMALLINT)) + ) +) +---- +ARRAY + +# typed widening: decimal + decimal +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST(1.2 AS DECIMAL(4, 1))), + cast_to_variant(CAST(12.345 AS DECIMAL(8, 3))) + ) +) +---- +ARRAY + +# typed widening: integer + decimal +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST(1 AS TINYINT)), + cast_to_variant(CAST(12.3 AS DECIMAL(4, 1))) + ) +) +---- +ARRAY + +# typed widening: bigint + decimal +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST(1 AS BIGINT)), + cast_to_variant(CAST(12.3 AS DECIMAL(4, 1))) + ) +) +---- +ARRAY + +# typed widening: decimal + float +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST(1.2 AS DECIMAL(4, 1))), + cast_to_variant(CAST(2.5 AS REAL)) + ) +) +---- +ARRAY + +# typed widening: date + timestamp (ntz) +query T +SELECT variant_schema( + variant_list_construct( + cast_to_variant(CAST('1990-01-01' AS DATE)), + cast_to_variant(CAST('2015-05-14 00:00:00' AS TIMESTAMP)) + ) +) +---- +ARRAY + +# array of objects +query T +SELECT variant_schema(json_to_variant('[{"a":1},{"a":2}]')) +---- +ARRAY> + +# empty object +query T +SELECT variant_schema(json_to_variant('{}')) +---- 
+OBJECT<> + +# empty array +query T +SELECT variant_schema(json_to_variant('[]')) +---- +ARRAY + +# field ordering +query T +SELECT variant_schema(json_to_variant('{"b":1,"a":2}')) +---- +OBJECT + +# last key wins? +query T +SELECT variant_schema(json_to_variant('{"a": 1, "a": {"b":2}}')) +---- +OBJECT> diff --git a/tests/test_files/variant_schema_agg.slt b/tests/test_files/variant_schema_agg.slt new file mode 100644 index 0000000..d05f12c --- /dev/null +++ b/tests/test_files/variant_schema_agg.slt @@ -0,0 +1,180 @@ +# tests the variant_schema_agg udaf +# this function takes a Variant Array +# and extracts its SQL schema + +# same schema +statement ok +CREATE TABLE t as VALUES +(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), +(json_to_variant('{"wing": {"ding": "man"}}')); + +query T +SELECT variant_schema_agg(column1) from t; +---- +OBJECT> + +# conflicting schema +statement ok +CREATE TABLE t_conflicting as VALUES +(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), +(json_to_variant('{"wing": 123}')); + +query T +SELECT variant_schema_agg(column1) from t_conflicting; +---- +OBJECT + +# null row +statement ok +CREATE TABLE t_nulls AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('null')), +(json_to_variant('{"a": 2}')); + +query T +SELECT variant_schema_agg(column1) FROM t_nulls; +---- +OBJECT + +# numeric widening +statement ok +CREATE TABLE t_nums AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('{"a": 2.5}')); + +query T +SELECT variant_schema_agg(column1) FROM t_nums; +---- +OBJECT + +# typed widening: tinyint + smallint +statement ok +CREATE TABLE t_small_widen AS VALUES +(cast_to_variant(CAST(1 AS TINYINT))), +(cast_to_variant(CAST(2 AS SMALLINT))); + +query T +SELECT variant_schema_agg(column1) FROM t_small_widen; +---- +Int16 + +# typed widening: decimal + decimal +statement ok +CREATE TABLE t_dec_widen AS VALUES +(cast_to_variant(CAST(1.2 AS DECIMAL(4, 1)))), +(cast_to_variant(CAST(12.345 AS DECIMAL(8, 
3)))); + +query T +SELECT variant_schema_agg(column1) FROM t_dec_widen; +---- +Decimal128(5, 3) + +# typed widening: integer + decimal +statement ok +CREATE TABLE t_int_dec_widen AS VALUES +(cast_to_variant(CAST(1 AS TINYINT))), +(cast_to_variant(CAST(12.3 AS DECIMAL(4, 1)))); + +query T +SELECT variant_schema_agg(column1) FROM t_int_dec_widen; +---- +Decimal128(11, 1) + +# typed widening: bigint + decimal +statement ok +CREATE TABLE t_bigint_dec_widen AS VALUES +(cast_to_variant(CAST(1 AS BIGINT))), +(cast_to_variant(CAST(12.3 AS DECIMAL(4, 1)))); + +query T +SELECT variant_schema_agg(column1) FROM t_bigint_dec_widen; +---- +Decimal128(21, 1) + +# typed widening: decimal + float +statement ok +CREATE TABLE t_dec_float_widen AS VALUES +(cast_to_variant(CAST(1.2 AS DECIMAL(4, 1)))), +(cast_to_variant(CAST(2.5 AS REAL))); + +query T +SELECT variant_schema_agg(column1) FROM t_dec_float_widen; +---- +Float64 + +# typed widening: date + timestamp (ntz) +statement ok +CREATE TABLE t_date_ts_widen AS VALUES +(cast_to_variant(CAST('1990-01-01' AS DATE))), +(cast_to_variant(CAST('2015-05-14 00:00:00' AS TIMESTAMP))); + +query T +SELECT variant_schema_agg(column1) FROM t_date_ts_widen; +---- +Timestamp(µs) + +# field appears later +statement ok +CREATE TABLE t_sparse AS VALUES +(json_to_variant('{}')), +(json_to_variant('{"a": 1}')); + +query T +SELECT variant_schema_agg(column1) FROM t_sparse; +---- +OBJECT + +# conflicting array of objects +statement ok +CREATE TABLE t_arr_objs AS VALUES +(json_to_variant('[{"a":1}]')), +(json_to_variant('[{"a":"x"}]')); + +query T +SELECT variant_schema_agg(column1) FROM t_arr_objs; +---- +ARRAY> + +# empty aggregates +statement ok +CREATE TABLE t_empty AS VALUES +(json_to_variant('{}')), +(json_to_variant('{}')); + +query T +SELECT variant_schema_agg(column1) FROM t_empty; +---- +OBJECT<> + +# field ordering +statement ok +CREATE TABLE t_order AS VALUES +(json_to_variant('{"b":1}')), +(json_to_variant('{"a":2}')); + +query T +SELECT 
variant_schema_agg(column1) FROM t_order; +---- +OBJECT + +# root conflict +statement ok +CREATE TABLE t_root_conflict AS VALUES +(json_to_variant('{"a":1}')), +(json_to_variant('[1,2,3]')); + +query T +SELECT variant_schema_agg(column1) FROM t_root_conflict; +---- +VARIANT + +# mixed root +statement ok +CREATE TABLE t_mixed AS VALUES +(json_to_variant('1')), +(json_to_variant('{"a": 1}')); + +query T +SELECT variant_schema_agg(column1) FROM t_mixed; +---- +VARIANT