Skip to content

Commit ef91857

Browse files
authored
arrow-select: add support for merging primitive dictionary values (#7519)
Previously, `should_merge_dictionary_values` would always return false from the match arm that creates the `ptr_eq` closure for value types that were not {Large}{Utf8,Binary}. This could lead to excessive memory usage. # Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #7518 # What changes are included in this PR? Update the match arm in `should_merge_dictionary_values` so that it does not short-circuit on primitive types. Also use the byte representations of primitive values to reuse the `Interner` pipeline used for the byte types. # Are there any user-facing changes? No
1 parent f92ff18 commit ef91857

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

arrow-select/src/concat.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,49 @@ mod tests {
10811081
assert!((30..40).contains(&values_len), "{values_len}")
10821082
}
10831083

1084+
#[test]
fn test_primitive_dictionary_merge() {
    // Every key is 1, so all five rows reference the same dictionary slot.
    let keys = vec![1; 5];
    let values = (10..20).collect::<Vec<_>>();

    // Builds a fresh dictionary with its own, separately allocated values array.
    let make_dict = || {
        DictionaryArray::new(
            Int8Array::from(keys.clone()),
            Arc::new(Int32Array::from(values.clone())),
        )
    };
    let dict = make_dict();
    let other = make_dict();

    // Concatenating a dictionary with itself: the pointer equality check
    // succeeds, so no merge takes place and the single values buffer is
    // reused as-is.
    let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
    let reused_values = result_same_dictionary
        .as_dictionary::<Int8Type>()
        .values();
    assert!(dict.values().to_data().ptr_eq(&reused_values.to_data()));
    assert_eq!(reused_values.len(), values.len(),);

    // Concatenating two equal-but-distinct dictionaries: the values are
    // merged, and only the one value that the keys reference is kept.
    let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
    let merged_values = result_cloned_dictionary
        .as_dictionary::<Int8Type>()
        .values();
    assert_eq!(merged_values.len(), 1);
}
1126+
10841127
#[test]
10851128
fn test_concat_string_sizes() {
10861129
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();

arrow-select/src/dictionary.rs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
use crate::interleave::interleave;
1919
use ahash::RandomState;
2020
use arrow_array::builder::BooleanBufferBuilder;
21-
use arrow_array::cast::AsArray;
2221
use arrow_array::types::{
23-
ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
22+
ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
23+
LargeUtf8Type, Utf8Type,
2424
};
25-
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
26-
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
25+
use arrow_array::{cast::AsArray, downcast_primitive};
26+
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
27+
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
2728
use arrow_schema::{ArrowError, DataType};
2829

2930
/// A best effort interner that maintains a fixed number of buckets
@@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
102103
}
103104

104105
/// A type-erased function that compares two arrays for pointer equality
105-
type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
106+
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
106107

107108
/// A weak heuristic of whether to merge dictionary values that aims to only
108109
/// perform the expensive merge computation when it is likely to yield at least
@@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
115116
) -> bool {
116117
use DataType::*;
117118
let first_values = dictionaries[0].values().as_ref();
118-
let ptr_eq: Box<PtrEq> = match first_values.data_type() {
119-
Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
120-
LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
121-
Binary => Box::new(bytes_ptr_eq::<BinaryType>),
122-
LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
123-
_ => return false,
119+
let ptr_eq: PtrEq = match first_values.data_type() {
120+
Utf8 => bytes_ptr_eq::<Utf8Type>,
121+
LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
122+
Binary => bytes_ptr_eq::<BinaryType>,
123+
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
124+
dt => {
125+
if !dt.is_primitive() {
126+
return false;
127+
}
128+
|a, b| a.to_data().ptr_eq(&b.to_data())
129+
}
124130
};
125131

126132
let mut single_dictionary = true;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
233239
builder.finish()
234240
}
235241

242+
/// Process primitive array values to bytes
243+
fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
244+
array: &'a PrimitiveArray<T>,
245+
mask: &BooleanBuffer,
246+
) -> Vec<(usize, Option<&'a [u8]>)>
247+
where
248+
T::Native: ToByteSlice,
249+
{
250+
let mut out = Vec::with_capacity(mask.count_set_bits());
251+
let values = array.values();
252+
for idx in mask.set_indices() {
253+
out.push((
254+
idx,
255+
array.is_valid(idx).then_some(values[idx].to_byte_slice()),
256+
))
257+
}
258+
out
259+
}
260+
261+
// Adapter passed to `downcast_primitive!`: downcasts `$values` to the primitive
// array of type `$prim` and forwards it, together with `$selection`, to
// `masked_primitives_to_bytes`.
macro_rules! masked_primitive_to_bytes_helper {
    ($prim:ty, $values:expr, $selection:expr) => {
        masked_primitives_to_bytes::<$prim>($values.as_primitive(), $selection)
    };
}
266+
236267
/// Returns a Vec containing, for each set index in `mask`, that index and the bytes of the value at that index
237268
fn get_masked_values<'a>(
238269
array: &'a dyn Array,
239270
mask: &BooleanBuffer,
240271
) -> Vec<(usize, Option<&'a [u8]>)> {
241-
match array.data_type() {
272+
downcast_primitive! {
273+
array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
242274
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
243275
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
244276
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
245277
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
246-
_ => unimplemented!(),
278+
_ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
247279
}
248280
}
249281

0 commit comments

Comments
 (0)