18
18
use crate :: interleave:: interleave;
19
19
use ahash:: RandomState ;
20
20
use arrow_array:: builder:: BooleanBufferBuilder ;
21
- use arrow_array:: cast:: AsArray ;
22
21
use arrow_array:: types:: {
23
- ArrowDictionaryKeyType , BinaryType , ByteArrayType , LargeBinaryType , LargeUtf8Type , Utf8Type ,
22
+ ArrowDictionaryKeyType , ArrowPrimitiveType , BinaryType , ByteArrayType , LargeBinaryType ,
23
+ LargeUtf8Type , Utf8Type ,
24
24
} ;
25
- use arrow_array:: { Array , ArrayRef , DictionaryArray , GenericByteArray } ;
26
- use arrow_buffer:: { ArrowNativeType , BooleanBuffer , ScalarBuffer } ;
25
+ use arrow_array:: { cast:: AsArray , downcast_primitive} ;
26
+ use arrow_array:: { Array , ArrayRef , DictionaryArray , GenericByteArray , PrimitiveArray } ;
27
+ use arrow_buffer:: { ArrowNativeType , BooleanBuffer , ScalarBuffer , ToByteSlice } ;
27
28
use arrow_schema:: { ArrowError , DataType } ;
28
29
29
30
/// A best effort interner that maintains a fixed number of buckets
@@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
102
103
}
103
104
104
105
/// A type-erased function that compares two array for pointer equality
105
- type PtrEq = dyn Fn ( & dyn Array , & dyn Array ) -> bool ;
106
+ type PtrEq = fn ( & dyn Array , & dyn Array ) -> bool ;
106
107
107
108
/// A weak heuristic of whether to merge dictionary values that aims to only
108
109
/// perform the expensive merge computation when it is likely to yield at least
@@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
115
116
) -> bool {
116
117
use DataType :: * ;
117
118
let first_values = dictionaries[ 0 ] . values ( ) . as_ref ( ) ;
118
- let ptr_eq: Box < PtrEq > = match first_values. data_type ( ) {
119
- Utf8 => Box :: new ( bytes_ptr_eq :: < Utf8Type > ) ,
120
- LargeUtf8 => Box :: new ( bytes_ptr_eq :: < LargeUtf8Type > ) ,
121
- Binary => Box :: new ( bytes_ptr_eq :: < BinaryType > ) ,
122
- LargeBinary => Box :: new ( bytes_ptr_eq :: < LargeBinaryType > ) ,
123
- _ => return false ,
119
+ let ptr_eq: PtrEq = match first_values. data_type ( ) {
120
+ Utf8 => bytes_ptr_eq :: < Utf8Type > ,
121
+ LargeUtf8 => bytes_ptr_eq :: < LargeUtf8Type > ,
122
+ Binary => bytes_ptr_eq :: < BinaryType > ,
123
+ LargeBinary => bytes_ptr_eq :: < LargeBinaryType > ,
124
+ dt => {
125
+ if !dt. is_primitive ( ) {
126
+ return false ;
127
+ }
128
+ |a, b| a. to_data ( ) . ptr_eq ( & b. to_data ( ) )
129
+ }
124
130
} ;
125
131
126
132
let mut single_dictionary = true ;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
233
239
builder. finish ( )
234
240
}
235
241
242
+ /// Process primitive array values to bytes
243
+ fn masked_primitives_to_bytes < ' a , T : ArrowPrimitiveType > (
244
+ array : & ' a PrimitiveArray < T > ,
245
+ mask : & BooleanBuffer ,
246
+ ) -> Vec < ( usize , Option < & ' a [ u8 ] > ) >
247
+ where
248
+ T :: Native : ToByteSlice ,
249
+ {
250
+ let mut out = Vec :: with_capacity ( mask. count_set_bits ( ) ) ;
251
+ let values = array. values ( ) ;
252
+ for idx in mask. set_indices ( ) {
253
+ out. push ( (
254
+ idx,
255
+ array. is_valid ( idx) . then_some ( values[ idx] . to_byte_slice ( ) ) ,
256
+ ) )
257
+ }
258
+ out
259
+ }
260
+
261
+ macro_rules! masked_primitive_to_bytes_helper {
262
+ ( $t: ty, $array: expr, $mask: expr) => {
263
+ masked_primitives_to_bytes:: <$t>( $array. as_primitive( ) , $mask)
264
+ } ;
265
+ }
266
+
236
267
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
237
268
fn get_masked_values < ' a > (
238
269
array : & ' a dyn Array ,
239
270
mask : & BooleanBuffer ,
240
271
) -> Vec < ( usize , Option < & ' a [ u8 ] > ) > {
241
- match array. data_type ( ) {
272
+ downcast_primitive ! {
273
+ array. data_type( ) => ( masked_primitive_to_bytes_helper, array, mask) ,
242
274
DataType :: Utf8 => masked_bytes( array. as_string:: <i32 >( ) , mask) ,
243
275
DataType :: LargeUtf8 => masked_bytes( array. as_string:: <i64 >( ) , mask) ,
244
276
DataType :: Binary => masked_bytes( array. as_binary:: <i32 >( ) , mask) ,
245
277
DataType :: LargeBinary => masked_bytes( array. as_binary:: <i64 >( ) , mask) ,
246
- _ => unimplemented ! ( ) ,
278
+ _ => unimplemented!( "Dictionary merging for type {} is not implemented" , array . data_type ( ) ) ,
247
279
}
248
280
}
249
281
0 commit comments