# fix: try merging dictionary as a fallback on overflow error #8652
Changes from 34 commits
`arrow-data/src/transform/dictionary.rs` (new file, `@@ -0,0 +1,140 @@`):

```rust
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;

use arrow_buffer::ArrowNativeType;
use arrow_schema::{ArrowError, DataType};

use crate::{
    ArrayData,
    transform::{_MutableArrayData, Extend, MutableArrayData, utils::iter_in_bytes},
};

pub(crate) fn merge_dictionaries<'a>(
    key_data_type: &DataType,
    value_data_type: &DataType,
    dicts: &[&'a ArrayData],
) -> Result<(Vec<Extend<'a>>, ArrayData), ArrowError> {
    match key_data_type {
        DataType::UInt8 => merge_dictionaries_casted::<u8>(value_data_type, dicts),
        DataType::UInt16 => merge_dictionaries_casted::<u16>(value_data_type, dicts),
        DataType::UInt32 => merge_dictionaries_casted::<u32>(value_data_type, dicts),
        DataType::UInt64 => merge_dictionaries_casted::<u64>(value_data_type, dicts),
        DataType::Int8 => merge_dictionaries_casted::<i8>(value_data_type, dicts),
        DataType::Int16 => merge_dictionaries_casted::<i16>(value_data_type, dicts),
        DataType::Int32 => merge_dictionaries_casted::<i32>(value_data_type, dicts),
        DataType::Int64 => merge_dictionaries_casted::<i64>(value_data_type, dicts),
        _ => unreachable!(),
    }
}

fn merge_dictionaries_casted<'a, K: ArrowNativeType>(
    data_type: &DataType,
    dicts: &[&'a ArrayData],
) -> Result<(Vec<Extend<'a>>, ArrayData), ArrowError> {
    let mut dedup = HashMap::new();
    let mut indices = vec![];
    let mut data_refs = vec![];
    let new_dict_keys = dicts
        .iter()
        .enumerate()
        .map(|(dict_idx, dict)| {
            let value_data = dict.child_data().first().unwrap();
            let old_keys = dict.buffer::<K>(0);
            data_refs.push(value_data);
            let mut new_keys = vec![K::usize_as(0); old_keys.len()];
            let values = iter_in_bytes(data_type, value_data);
            for (key_index, old_key) in old_keys.iter().enumerate() {
                if dict.is_valid(key_index) {
                    let value = values[old_key.as_usize()];
                    match K::from_usize(dedup.len()) {
                        Some(idx) => {
                            let idx_for_value = dedup.entry(value).or_insert(idx);
                            // a new entry
                            if *idx_for_value == idx {
                                indices.push((dict_idx, old_key.as_usize()));
                            }

                            new_keys[key_index] = *idx_for_value;
                        }
                        // the built dictionary has reached the cap of the key type
                        None => match dedup.get(value) {
                            // as long as this value has already been indexed,
                            // the merged dictionary is still valid
                            Some(previous_key) => {
                                new_keys[key_index] = *previous_key;
                            }
                            None => return Err(ArrowError::DictionaryKeyOverflowError),
```
> **Author:** let me add more coverage on this
```rust
                        },
                    };
                }
            }

            Ok(new_keys)
        })
        .collect::<Result<Vec<Vec<K>>, ArrowError>>()?;
    let shared_value_data = if indices.is_empty() {
        ArrayData::new_empty(data_refs[0].data_type())
    } else {
        let new_values_data = MutableArrayData::new(data_refs, false, indices.len());
        interleave(new_values_data, indices)
    };

    Ok((
        new_dict_keys
            .into_iter()
            .map(|keys| {
                Box::new(
                    move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
                        mutable
                            .buffer1
                            .extend_from_slice::<K>(&keys[start..start + len]);
                    },
                ) as Extend
            })
            .collect::<Vec<Extend>>(),
        shared_value_data,
    ))
}

fn interleave(mut array_data: MutableArrayData, indices: Vec<(usize, usize)>) -> ArrayData {
    let mut cur_array = indices[0].0;

    let mut start_row_idx = indices[0].1;
    let mut end_row_idx = start_row_idx + 1;

    for (array, row) in indices.iter().skip(1).copied() {
        if array == cur_array && row == end_row_idx {
            // subsequent row in same batch
            end_row_idx += 1;
            continue;
        }

        // emit current batch of rows for current buffer
        array_data.extend(cur_array, start_row_idx, end_row_idx);

        // start new batch of rows
        cur_array = array;
        start_row_idx = row;
        end_row_idx = start_row_idx + 1;
    }

    // emit final batch of rows
    array_data.extend(cur_array, start_row_idx, end_row_idx);
    array_data.freeze()
}
```
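To make the dedup pass above concrete, here is a toy rendition (not the PR's code: plain `Vec`/`HashMap` stand-ins replace `ArrayData`, and the `merge` helper is hypothetical) of how each input dictionary's keys get remapped onto a single deduplicated value set:

```rust
use std::collections::HashMap;

/// Toy model of the merge: each dictionary is (keys, values); returns the
/// remapped keys for every input plus the deduplicated shared values.
fn merge<'a>(dicts: &[(Vec<u8>, Vec<&'a str>)]) -> (Vec<Vec<u8>>, Vec<&'a str>) {
    let mut dedup: HashMap<&str, u8> = HashMap::new();
    let mut shared = Vec::new();
    let remapped = dicts
        .iter()
        .map(|(keys, values)| {
            keys.iter()
                .map(|&k| {
                    let value = values[k as usize];
                    // First sighting appends to the shared values; repeats
                    // reuse the index recorded in the dedup map.
                    *dedup.entry(value).or_insert_with(|| {
                        shared.push(value);
                        (shared.len() - 1) as u8
                    })
                })
                .collect()
        })
        .collect();
    (remapped, shared)
}

fn main() {
    let a = (vec![0u8, 1, 0], vec!["x", "y"]);
    let b = (vec![0u8, 1], vec!["y", "z"]);
    let (keys, shared) = merge(&[a, b]);
    assert_eq!(shared, ["x", "y", "z"]);
    assert_eq!(keys, [vec![0, 1, 0], vec![1, 2]]); // "y" reuses slot 1
}
```

The real `merge_dictionaries_casted` does the same bookkeeping, but records `(dict_idx, row)` pairs for first-seen values so `interleave` can copy them out of the original value buffers in batched runs.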
`arrow-data/src/transform/mod.rs`:

```diff
@@ -25,11 +25,13 @@ use crate::bit_mask::set_bits;
 use arrow_buffer::buffer::{BooleanBuffer, NullBuffer};
 use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, bit_util, i256};
 use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode};
+use dictionary::merge_dictionaries;
 use half::f16;
 use num_integer::Integer;
 use std::mem;

 mod boolean;
+mod dictionary;
 mod fixed_binary;
 mod fixed_size_list;
 mod list;
@@ -604,7 +606,7 @@ impl<'a> MutableArrayData<'a> {
         };

         // Get the dictionary if any, and if it is a concatenation of multiple
-        let (dictionary, dict_concat) = match &data_type {
+        let (mut dictionary, dict_concat) = match &data_type {
             DataType::Dictionary(_, _) => {
                 // If more than one dictionary, concatenate dictionaries together
                 let dict_concat = !arrays
@@ -660,9 +662,9 @@ impl<'a> MutableArrayData<'a> {
         });

         let extend_values = match &data_type {
-            DataType::Dictionary(_, _) => {
+            DataType::Dictionary(key_data_type, value_data_type) => {
                 let mut next_offset = 0;
-                let extend_values: Result<Vec<_>, _> = arrays
+                let result = arrays
                     .iter()
                     .map(|array| {
                         let offset = next_offset;
@@ -672,12 +674,24 @@ impl<'a> MutableArrayData<'a> {
                             next_offset += dict_len;
                         }

-                        build_extend_dictionary(array, offset, offset + dict_len)
+                        // -1 since offset is exclusive
```
> **Contributor:** I don't understand this comment or change, so I reverted the change:
>
> ```diff
> (venv) andrewlamb@Andrews-MacBook-Pro-3:~/Software/arrow-rs$ git diff
> diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs
> index 12b03bbdf0..76a116c4cd 100644
> --- a/arrow-data/src/transform/mod.rs
> +++ b/arrow-data/src/transform/mod.rs
> @@ -674,8 +674,7 @@ impl<'a> MutableArrayData<'a> {
>                              next_offset += dict_len;
>                          }
> -                        // -1 since offset is exclusive
> -                        build_extend_dictionary(array, offset, 1.max(offset + dict_len) - 1)
> +                        build_extend_dictionary(array, offset, offset + dict_len)
>                              .ok_or(ArrowError::DictionaryKeyOverflowError)
>                      })
>                      .collect::<Result<Vec<_>, ArrowError>>();
> ```
>
> And the tests still pass 🤔
>
> **Author:** This will trigger the error path: when execution reaches this function, `offset` will be 0 and `dict_len` will be 256, and `build_extend_dictionary` will try to cast 256 as `u8`, which throws `DictionaryKeyOverflowError` when it shouldn't. The tests pass anyway because we already added a fallback for this error.
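To illustrate the author's point with a standalone sketch (not PR code; it only assumes `ArrowNativeType::from_usize`, which this PR itself uses): a dictionary of 256 values is addressable with `u8` keys `0..=255`, but the exclusive end offset 256 is not itself representable, so casting it spuriously signals overflow.

```rust
use arrow_buffer::ArrowNativeType;

fn main() {
    // A dictionary with 256 values is fine for u8 keys (0..=255)...
    assert_eq!(u8::from_usize(255), Some(255));
    // ...but casting the exclusive end offset (0 + 256) overflows, which
    // would spuriously report DictionaryKeyOverflowError without the -1.
    assert_eq!(u8::from_usize(256), None);
}
```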
```diff
+                        build_extend_dictionary(array, offset, 1.max(offset + dict_len) - 1)
```
> **Contributor:** Rather than converting an Option --> Result, what about changing …
>
> **Author:** comment addressed
```diff
                             .ok_or(ArrowError::DictionaryKeyOverflowError)
                     })
-                    .collect();
-
-                extend_values.expect("MutableArrayData::new is infallible")
+                    .collect::<Result<Vec<_>, ArrowError>>();
+                match result {
+                    Err(_) => {
```
> **Contributor:** I think we should only retry when the Err is …
>
> **Contributor:** I also think it would help to add a comment explaining the rationale for this fallback -- namely, something like: "if the dictionary key overflows, it means there are too many keys in the concatenated dictionary -- in that case, fall back to the slower path of merging (deduplicating) the dictionaries".
>
> **Contributor:** Also, I was confused for a while about how this could detect an error, since it happens when constructing the extends, not when actually running them. I think I understand now (it hasn't changed in this PR): the maximum dictionary key is computed based on each dictionary's size, which makes sense.
>
> **Author:** I added some more comments.
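A standalone sketch of the narrowing the first comment suggests (hypothetical; the PR as written matches any `Err(_)`):

```rust
use arrow_schema::ArrowError;

fn main() {
    let result: Result<Vec<()>, ArrowError> = Err(ArrowError::DictionaryKeyOverflowError);
    match result {
        // Only a key overflow should trigger the merge fallback; any other
        // error would be a genuine failure and should propagate.
        Err(ArrowError::DictionaryKeyOverflowError) => println!("fall back to merging"),
        Err(other) => panic!("unexpected error: {other}"),
        Ok(_) => println!("fast path: concatenated dictionaries fit"),
    }
}
```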
```diff
+                        let (extends, merged_dictionary_values) = merge_dictionaries(
+                            key_data_type.as_ref(),
+                            value_data_type.as_ref(),
+                            &arrays,
+                        )
+                        .expect("fail merging dictionary");
+                        dictionary = Some(merged_dictionary_values);
+                        extends
+                    }
+                    Ok(extends) => extends,
+                }
             }
             DataType::BinaryView | DataType::Utf8View => {
                 let mut next_offset = 0u32;
@@ -705,6 +719,7 @@ impl<'a> MutableArrayData<'a> {
             buffer2,
             child_data,
         };
+
         Self {
             arrays,
             data,
@@ -841,6 +856,9 @@ mod test {
     use arrow_schema::Field;
     use std::sync::Arc;

+    #[test]
+    fn test_dictionary_overflow() {}
+
     #[test]
     fn test_list_append_with_capacities() {
         let array = ArrayData::new_empty(&DataType::List(Arc::new(Field::new(
```
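The empty `test_dictionary_overflow` above is presumably filled in by a later commit; a sketch of the scenario it would need to cover (made-up data, constructed via the `arrow-array` crate) is two `u8`-keyed dictionaries whose value buffers total 400 entries but only 256 distinct values, so plain concatenation overflows while the merge fallback fits:

```rust
use std::sync::Arc;

use arrow_array::types::UInt8Type;
use arrow_array::{Array, DictionaryArray, StringArray, UInt8Array};
use arrow_data::transform::MutableArrayData;

fn main() {
    // dict_a holds values "v0".."v199", dict_b holds "v56".."v255":
    // 400 raw entries combined, but only 256 distinct values, so the
    // merged dictionary still fits u8 keys (0..=255).
    let values_a: StringArray = (0..200).map(|i| Some(format!("v{i}"))).collect();
    let dict_a = DictionaryArray::<UInt8Type>::try_new(
        UInt8Array::from_iter_values(0..=199),
        Arc::new(values_a),
    )
    .unwrap();

    let values_b: StringArray = (56..256).map(|i| Some(format!("v{i}"))).collect();
    let dict_b = DictionaryArray::<UInt8Type>::try_new(
        UInt8Array::from_iter_values(0..=199),
        Arc::new(values_b),
    )
    .unwrap();

    let (a, b) = (dict_a.to_data(), dict_b.to_data());
    // Before this PR, constructing the MutableArrayData panicked via
    // `expect("MutableArrayData::new is infallible")`; with the fallback
    // it merges (deduplicates) the dictionaries instead.
    let mut mutable = MutableArrayData::new(vec![&a, &b], false, 400);
    mutable.extend(0, 0, a.len());
    mutable.extend(1, 0, b.len());
    assert_eq!(mutable.freeze().len(), 400);
}
```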
`arrow-data/src/transform/utils.rs`:

```diff
@@ -16,9 +16,12 @@
 // under the License.

 use arrow_buffer::{ArrowNativeType, MutableBuffer, bit_util};
+use arrow_schema::DataType;
 use num_integer::Integer;
 use num_traits::CheckedAdd;

+use crate::ArrayData;
+
 /// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero.
 #[inline]
 pub(super) fn resize_for_bits(buffer: &mut MutableBuffer, len: usize) {
@@ -58,6 +61,37 @@ pub(super) unsafe fn get_last_offset<T: ArrowNativeType>(offset_buffer: &MutableBuffer) -> T {
     *unsafe { offsets.get_unchecked(offsets.len() - 1) }
 }

+fn iter_in_bytes_variable_sized<T: ArrowNativeType + Integer>(data: &ArrayData) -> Vec<&[u8]> {
+    let offsets = data.buffer::<T>(0);
+
+    // the `ArrayData` offset is not applied to the values buffer; it is
+    // already applied via the offset buffer.
+    let values = data.buffers()[1].as_slice();
+    (0..data.len())
+        .map(move |i| {
+            let start = offsets[i].to_usize().unwrap();
+            let end = offsets[i + 1].to_usize().unwrap();
+            &values[start..end]
+        })
+        .collect::<Vec<_>>()
+}
+
+fn iter_in_bytes_fixed_sized(data: &ArrayData, size: usize) -> Vec<&[u8]> {
+    let values = &data.buffers()[0].as_slice()[data.offset() * size..];
+    values.chunks(size).collect::<Vec<_>>()
+}
+
+/// iterate values in raw bytes regardless of nullability
+pub(crate) fn iter_in_bytes<'a>(data_type: &DataType, data: &'a ArrayData) -> Vec<&'a [u8]> {
```
> **Contributor:** this is called …
>
> **Author:** I renamed it.
```diff
+    if data_type.is_primitive() {
+        return iter_in_bytes_fixed_sized(data, data_type.primitive_width().unwrap());
+    }
+    match data_type {
+        DataType::Utf8 | DataType::Binary => iter_in_bytes_variable_sized::<i32>(data),
+        DataType::LargeUtf8 | DataType::LargeBinary => iter_in_bytes_variable_sized::<i64>(data),
+        _ => unimplemented!("iter in bytes is not supported for {data_type}"),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::transform::utils::extend_offsets;
```

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we please leave some comments about what this function does and why it is needed? (aka explain the overflow backup case)
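For reference, what the byte-slice view behind `iter_in_bytes` amounts to for a `Utf8` array, re-derived by hand on made-up data (a standalone sketch; the private helpers themselves are not callable from outside the crate):

```rust
use arrow_array::{Array, StringArray};

fn main() {
    let arr = StringArray::from(vec![Some("ab"), None, Some("cde")]);
    let data = arr.to_data();

    // Mirrors iter_in_bytes_variable_sized::<i32>: slice the values buffer
    // between consecutive offsets, ignoring validity (a null element still
    // yields a slice, here an empty one).
    let offsets: &[i32] = data.buffer::<i32>(0);
    let values = data.buffers()[1].as_slice();
    let views: Vec<&[u8]> = (0..data.len())
        .map(|i| &values[offsets[i] as usize..offsets[i + 1] as usize])
        .collect();

    let expected: [&[u8]; 3] = [b"ab", b"", b"cde"];
    assert_eq!(views, expected);
}
```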