Skip to content

Commit 1664f95

Browse files
committed
arrow-select: add support for merging primitive dictionary values
Previously, should_merge_dictionaries would always return false in the ptr_eq closure creation match arm for types that were not {Large}{Utf8,Binary}. This could lead to excessive memory usage.
1 parent 950f4d0 commit 1664f95

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

arrow-select/src/concat.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,49 @@ mod tests {
10811081
assert!((30..40).contains(&values_len), "{values_len}")
10821082
}
10831083

1084+
#[test]
1085+
fn test_primitive_dictionary_merge() {
1086+
// Same value repeated 5 times.
1087+
let keys = vec![1; 5];
1088+
let values = (10..20).collect::<Vec<_>>();
1089+
let dict = DictionaryArray::new(
1090+
Int8Array::from(keys.clone()),
1091+
Arc::new(Int32Array::from(values.clone())),
1092+
);
1093+
let other = DictionaryArray::new(
1094+
Int8Array::from(keys.clone()),
1095+
Arc::new(Int32Array::from(values.clone())),
1096+
);
1097+
1098+
let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
1099+
// Verify pointer equality check succeeds, and therefore the
1100+
// dictionaries are not merged. A single values buffer should be reused
1101+
// in this case.
1102+
assert!(dict.values().to_data().ptr_eq(
1103+
&result_same_dictionary
1104+
.as_dictionary::<Int8Type>()
1105+
.values()
1106+
.to_data()
1107+
));
1108+
assert_eq!(
1109+
result_same_dictionary
1110+
.as_dictionary::<Int8Type>()
1111+
.values()
1112+
.len(),
1113+
values.len(),
1114+
);
1115+
1116+
let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
1117+
// Should have only 1 underlying value since all keys reference it.
1118+
assert_eq!(
1119+
result_cloned_dictionary
1120+
.as_dictionary::<Int8Type>()
1121+
.values()
1122+
.len(),
1123+
1
1124+
);
1125+
}
1126+
10841127
#[test]
10851128
fn test_concat_string_sizes() {
10861129
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();

arrow-select/src/dictionary.rs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
use crate::interleave::interleave;
1919
use ahash::RandomState;
2020
use arrow_array::builder::BooleanBufferBuilder;
21-
use arrow_array::cast::AsArray;
2221
use arrow_array::types::{
23-
ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
22+
ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
23+
LargeUtf8Type, Utf8Type,
2424
};
25-
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
26-
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
25+
use arrow_array::{cast::AsArray, downcast_primitive};
26+
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
27+
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
2728
use arrow_schema::{ArrowError, DataType};
2829

2930
/// A best effort interner that maintains a fixed number of buckets
@@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
102103
}
103104

104105
/// A type-erased function that compares two array for pointer equality
105-
type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
106+
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
106107

107108
/// A weak heuristic of whether to merge dictionary values that aims to only
108109
/// perform the expensive merge computation when it is likely to yield at least
@@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
115116
) -> bool {
116117
use DataType::*;
117118
let first_values = dictionaries[0].values().as_ref();
118-
let ptr_eq: Box<PtrEq> = match first_values.data_type() {
119-
Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
120-
LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
121-
Binary => Box::new(bytes_ptr_eq::<BinaryType>),
122-
LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
123-
_ => return false,
119+
let ptr_eq: PtrEq = match first_values.data_type() {
120+
Utf8 => bytes_ptr_eq::<Utf8Type>,
121+
LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
122+
Binary => bytes_ptr_eq::<BinaryType>,
123+
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
124+
dt => {
125+
if !dt.is_primitive() {
126+
return false;
127+
}
128+
|a, b| a.to_data().ptr_eq(&b.to_data())
129+
}
124130
};
125131

126132
let mut single_dictionary = true;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
233239
builder.finish()
234240
}
235241

242+
/// Process primitive array values to bytes
243+
fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
244+
array: &'a PrimitiveArray<T>,
245+
mask: &BooleanBuffer,
246+
) -> Vec<(usize, Option<&'a [u8]>)>
247+
where
248+
T::Native: ToByteSlice,
249+
{
250+
let mut out = Vec::with_capacity(mask.count_set_bits());
251+
let values = array.values();
252+
for idx in mask.set_indices() {
253+
out.push((
254+
idx,
255+
array.is_valid(idx).then_some(values[idx].to_byte_slice()),
256+
))
257+
}
258+
out
259+
}
260+
261+
macro_rules! masked_primitive_to_bytes_helper {
262+
($t:ty, $array:expr, $mask:expr) => {
263+
masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
264+
};
265+
}
266+
236267
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
237268
fn get_masked_values<'a>(
238269
array: &'a dyn Array,
239270
mask: &BooleanBuffer,
240271
) -> Vec<(usize, Option<&'a [u8]>)> {
241-
match array.data_type() {
272+
downcast_primitive! {
273+
array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
242274
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
243275
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
244276
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
245277
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
246-
_ => unimplemented!(),
278+
_ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
247279
}
248280
}
249281

0 commit comments

Comments
 (0)