fix: try merging dictionary as a fallback on overflow error #8652
Open · duongcongtoai wants to merge 51 commits into apache:main from duongcongtoai:fix-overflow-on-interleave-list-of-dict
+831 −16
51 commits:

- fc36ee6 fix: try merging list of dict if possible
- e18080c feat: support interleaving list of struct with dict fields
- af5c581 test: add null test
- 661a0ee feat: handle other non merged fields
- c56354f test: struct list with mergable dict field
- a0a607d fix: handle all dict key size
- 549b8da fix: linting
- 249422f test: add bench
- 2ca0ef7 fix: key overflow
- b0f5e9c fix: naming
- 8da747d test: not used test
- f1b5e4d chore: doc
- d18f6f0 chore: add runbench
- 46554d3 chore: rm temp
- a5f192e test: add bench for list
- 96bef55 fix: data type
- 77f5371 fix: lint
- eff7926 feat: best effort merge dictionary on error
- 2331977 feat: simplify the fallback
- 2ea49cc chore: revert unrelated changes
- 3e0fafb chore: rm bench script
- 01a7bd4 chore: lint
- 319074e chore: some more comments
- b393332 feat: let mutablearraydata handle fallback
- 25e20de fix: lint
- 3c2d130 fix: handle when all keys are null
- 24edfa7 fix: negative overflow
- 68377e1 fix: minor comment
- 17bc58a fix: clippy
- 797e236 chore: more comment
- 14cda31 fix: clippy
- 2a9b544 fix: add license
- 407434d fix: fmt
- eaab4ff Merge branch 'main' into fix-overflow-on-interleave-list-of-dict
- 89777ec fix: more comments
- 1d49d2c fix: compile err
- a0e1e30 test: on more distinct keys
- 1998486 test: a case when overflow happens
- a1b10c9 test: use larger key size
- 042fca8 test: arrow-data
- b3124d4 Merge remote-tracking branch 'origin/main' into fix-overflow-on-inter…
- 50df8f4 fix: test returned extends closure
- df12d54 fix: better comment
- 7986107 fix: clippy
- bf1095e fix: more cov
- 8a5a028 chore: use different backed value
- 1746bcf Merge branch 'main' into fix-overflow-on-interleave-list-of-dict
- 9903f59 chore: add some comments on fallback
- 86677e6 Merge branch 'main' of https://github.com/apache/arrow-rs into fix-ov… (toaiduong-blip)
- 8f44ec1 chore: some more test case to track panic in concat
- 8390530 fix: some more test to reproduce key overflow

All commits are by duongcongtoai, except 86677e6 by toaiduong-blip.
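The failure mode this PR works around can be pictured without any arrow-rs types at all (illustration only, hypothetical helper names): naively concatenating dictionaries sizes the merged key space by the *sum* of the inputs' value counts, while a de-duplicating merge only needs the number of *distinct* values, which may still fit a small key type.

```rust
use std::collections::HashSet;

/// Key slots needed if the value buffers are naively concatenated.
fn keys_needed_naive(dicts: &[Vec<&str>]) -> usize {
    dicts.iter().map(|d| d.len()).sum()
}

/// Key slots needed after de-duplicating values across all inputs.
fn keys_needed_dedup(dicts: &[Vec<&str>]) -> usize {
    dicts.iter().flatten().collect::<HashSet<_>>().len()
}

fn main() {
    // two dictionaries that share most of their values
    let a = vec!["red", "green", "blue"];
    let b = vec!["green", "blue", "yellow"];
    // naive concat needs 6 key slots, dedup needs only 4
    assert_eq!(keys_needed_naive(&[a.clone(), b.clone()]), 6);
    assert_eq!(keys_needed_dedup(&[a, b]), 4);
}
```

Scale this up to u8 or u16 keys and many shared values, and the naive path overflows where the de-duplicated path still fits, which is exactly the gap the fallback below fills.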
New file (+298 lines):

```rust
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;

use arrow_buffer::ArrowNativeType;
use arrow_schema::{ArrowError, DataType};

use crate::{
    ArrayData,
    transform::{_MutableArrayData, Extend, MutableArrayData, utils::to_bytes_vec},
};

/// Fallback merge strategy used when optimized dictionary-merge paths cannot guarantee
/// correctness, e.g. some fast-path algorithms may emit duplicate keys, which can overflow
/// the index type even if the logical key space is large enough.
///
/// This implementation prioritizes correctness over speed: it performs a full scan of
/// every input dictionary's values, ensuring a de-duplicated, exhaustively validated
/// key space before constructing the merged dictionary.
///
/// The function returns:
/// - a vector of closures, one per input dictionary, writing that dictionary's
///   remapped keys into the output
/// - an `ArrayData` of the merged value array (not the dictionary)
pub(crate) fn merge_dictionaries<'a>(
    key_data_type: &DataType,
    value_data_type: &DataType,
    dicts: &[&'a ArrayData],
) -> Result<(Vec<Extend<'a>>, ArrayData), ArrowError> {
    match key_data_type {
        DataType::UInt8 => merge_dictionaries_casted::<u8>(value_data_type, dicts),
        DataType::UInt16 => merge_dictionaries_casted::<u16>(value_data_type, dicts),
        DataType::UInt32 => merge_dictionaries_casted::<u32>(value_data_type, dicts),
        DataType::UInt64 => merge_dictionaries_casted::<u64>(value_data_type, dicts),
        DataType::Int8 => merge_dictionaries_casted::<i8>(value_data_type, dicts),
        DataType::Int16 => merge_dictionaries_casted::<i16>(value_data_type, dicts),
        DataType::Int32 => merge_dictionaries_casted::<i32>(value_data_type, dicts),
        DataType::Int64 => merge_dictionaries_casted::<i64>(value_data_type, dicts),
        _ => unreachable!(),
    }
}

fn merge_dictionaries_casted<'a, K: ArrowNativeType>(
    data_type: &DataType,
    dicts: &[&'a ArrayData],
) -> Result<(Vec<Extend<'a>>, ArrayData), ArrowError> {
    let mut dedup = HashMap::new();
    let mut indices = vec![];
    let mut data_refs = vec![];
    let new_dict_keys = dicts
        .iter()
        .enumerate()
        .map(|(dict_idx, dict)| {
            let value_data = dict.child_data().first().unwrap();
            let old_keys = dict.buffer::<K>(0);
            data_refs.push(value_data);
            let mut new_keys = vec![K::usize_as(0); old_keys.len()];
            let values = to_bytes_vec(data_type, value_data);
            for (key_index, old_key) in old_keys.iter().enumerate() {
                if dict.is_valid(key_index) {
                    let value = values[old_key.as_usize()];
                    match K::from_usize(dedup.len()) {
                        Some(idx) => {
                            let idx_for_value = dedup.entry(value).or_insert(idx);
                            // a new entry
                            if *idx_for_value == idx {
                                indices.push((dict_idx, old_key.as_usize()));
                            }

                            new_keys[key_index] = *idx_for_value;
                        }
                        // the built dictionary has reached the capacity of the key type
                        None => match dedup.get(value) {
                            // as long as this value has already been indexed,
                            // the merged dictionary is still valid
                            Some(previous_key) => {
                                new_keys[key_index] = *previous_key;
                            }
                            None => return Err(ArrowError::DictionaryKeyOverflowError),
                        },
                    };
                }
            }

            Ok(new_keys)
        })
        .collect::<Result<Vec<Vec<K>>, ArrowError>>()?;
    let shared_value_data = if indices.is_empty() {
        ArrayData::new_empty(data_refs[0].data_type())
    } else {
        let new_values_data = MutableArrayData::new(data_refs, false, indices.len());
        interleave(new_values_data, indices)
    };

    Ok((
        new_dict_keys
            .into_iter()
            .map(|keys| {
                Box::new(
                    move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
                        mutable
                            .buffer1
                            .extend_from_slice::<K>(&keys[start..start + len]);
                    },
                ) as Extend
            })
            .collect::<Vec<Extend>>(),
        shared_value_data,
    ))
}
```

> Review thread on the `Err(ArrowError::DictionaryKeyOverflowError)` line — author: "let me add more coverage on this"
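The core of `merge_dictionaries_casted` can be condensed into a standalone sketch (hypothetical simplification: `&str` values and fixed `u8` keys stand in for `ArrayData` and the generic `K`):

```rust
use std::collections::HashMap;

/// Simplified sketch of the de-duplicating merge: remap every input
/// dictionary's rows into one shared, de-duplicated value space, failing
/// only when the number of *distinct* values exceeds the u8 key capacity.
fn merge_u8_dicts<'a>(
    dicts: &[Vec<&'a str>],
) -> Result<(Vec<Vec<u8>>, Vec<&'a str>), &'static str> {
    let mut dedup: HashMap<&str, u8> = HashMap::new();
    let mut merged_values = Vec::new();
    let mut remapped = Vec::with_capacity(dicts.len());
    for rows in dicts {
        let mut new_keys = Vec::with_capacity(rows.len());
        for v in rows {
            let key = match dedup.get(v) {
                // already indexed: reuse the existing key
                Some(k) => *k,
                // genuinely new value: allocate the next key,
                // or fail if the u8 key space is exhausted
                None => {
                    let k = u8::try_from(dedup.len()).map_err(|_| "dictionary key overflow")?;
                    dedup.insert(*v, k);
                    merged_values.push(*v);
                    k
                }
            };
            new_keys.push(key);
        }
        remapped.push(new_keys);
    }
    Ok((remapped, merged_values))
}

fn main() {
    let (keys, values) = merge_u8_dicts(&[vec!["a", "b"], vec!["b", "c"]]).unwrap();
    assert_eq!(values, vec!["a", "b", "c"]);
    assert_eq!(keys, vec![vec![0, 1], vec![1, 2]]);
}
```

Like the real implementation, a duplicate value never consumes a fresh key, so overflow is reported only when the merged keyspace is genuinely too small.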
```rust
fn interleave(mut array_data: MutableArrayData, indices: Vec<(usize, usize)>) -> ArrayData {
    let mut cur_array = indices[0].0;

    let mut start_row_idx = indices[0].1;
    let mut end_row_idx = start_row_idx + 1;

    for (array, row) in indices.iter().skip(1).copied() {
        if array == cur_array && row == end_row_idx {
            // subsequent row in same batch
            end_row_idx += 1;
            continue;
        }

        // emit current batch of rows for current buffer
        array_data.extend(cur_array, start_row_idx, end_row_idx);

        // start new batch of rows
        cur_array = array;
        start_row_idx = row;
        end_row_idx = start_row_idx + 1;
    }

    // emit final batch of rows
    array_data.extend(cur_array, start_row_idx, end_row_idx);
    array_data.freeze()
}

#[cfg(test)]
mod tests {
    use std::iter::once;

    use arrow_buffer::{ArrowNativeType, Buffer, ToByteSlice};
    use arrow_schema::{ArrowError, DataType};

    use crate::{
        ArrayData, new_buffers,
        transform::{_MutableArrayData, dictionary::merge_dictionaries},
    };

    fn create_dictionary_from_value_data<K: ArrowNativeType>(
        keys: Vec<K>,
        value: ArrayData,
        key_type: DataType,
    ) -> ArrayData {
        let keys_buffer = Buffer::from(keys.to_byte_slice());

        let dict_data_type =
            DataType::Dictionary(Box::new(key_type), Box::new(value.data_type().clone()));
        ArrayData::builder(dict_data_type)
            .len(keys.len())
            .add_buffer(keys_buffer)
            .add_child_data(value)
            .build()
            .unwrap()
    }

    fn create_dictionary<K: ArrowNativeType, V: ArrowNativeType>(
        keys: Vec<K>,
        value: Vec<V>,
        key_type: DataType,
        value_type: DataType,
    ) -> ArrayData {
        let keys_buffer = Buffer::from(keys.to_byte_slice());

        let value_data = ArrayData::builder(value_type.clone())
            .len(value.len())
            .add_buffer(Buffer::from(value.to_byte_slice()))
            .build()
            .unwrap();

        let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type));
        ArrayData::builder(dict_data_type)
            .len(keys.len())
            .add_buffer(keys_buffer)
            .add_child_data(value_data)
            .build()
            .unwrap()
    }

    // builds a string array of the given numbers' decimal representations,
    // e.g. ["0", "1", ..., "255"]
    fn make_numeric_string_array(numbers: Vec<u32>) -> ArrayData {
        let values = numbers.iter().map(|i| i.to_string()).collect::<String>();
        let mut acc = 0;

        let offset_iter = numbers
            .iter()
            .map(|i| if *i == 0 { 1 } else { i.ilog10() + 1 })
            .map(|length| {
                acc += length;
                acc
            });

        let offsets = once(0).chain(offset_iter).collect::<Vec<_>>();
        ArrayData::builder(DataType::Utf8)
            .len(numbers.len())
            .add_buffer(Buffer::from_slice_ref(offsets))
            .add_buffer(Buffer::from_slice_ref(values.as_bytes()))
            .build()
            .unwrap()
    }

    #[test]
    fn merge_string_value_dictionary() {
        let arr1 = create_dictionary_from_value_data(
            (0u8..=127).collect(),
            make_numeric_string_array((0..=127).collect()),
            DataType::UInt8,
        );
        let arr1_clone = create_dictionary_from_value_data(
            (0u8..=127).collect(),
            make_numeric_string_array((0..=127).collect()),
            DataType::UInt8,
        );
        let arr2 = create_dictionary_from_value_data(
            (0u8..=127).collect(),
            make_numeric_string_array((128..=255).collect()),
            DataType::UInt8,
        );
        // arr1 and arr2 together hold 256 distinct values, which exactly
        // fits the u8 key space (keys 0..=255)

        let (extends, merged_value_arr) = merge_dictionaries(
            &DataType::UInt8,
            &DataType::Utf8,
            &[&arr1, &arr2, &arr1_clone],
        )
        .unwrap();

        // this array is used as the value array for the new dictionary
        let expected_new_value = make_numeric_string_array((0..=255).collect());
        assert!(expected_new_value.eq(&merged_value_arr));

        let [buffer1, buffer2] = new_buffers(arr1.data_type(), 256);
        let mut data = _MutableArrayData {
            data_type: arr1.data_type().clone(),
            len: 0,
            null_count: 0,
            null_buffer: None,
            buffer1,
            buffer2,
            child_data: vec![],
        };

        // concat keys [0..=127] [128..=255] [0..=127]
        for (index, extend) in extends.iter().enumerate() {
            extend(&mut data, index, 0, 128)
        }
        // the key buffer produced by the extend closures is also correct
        let expected_key_raw_buffer = (0u8..=255).chain(0u8..=127).collect::<Vec<_>>();
        assert_eq!(data.buffer1.as_slice(), &expected_key_raw_buffer);
    }

    #[test]
    fn total_distinct_keys_in_input_arrays_greater_than_key_size() {
        // arr1 and arr2 together hold 131072 distinct values, which
        // overflows the u16 key space (at most 65536 entries)
        let arr1 = create_dictionary(
            (0u16..=65535).collect(),
            (0u32..=65535).collect(),
            DataType::UInt16,
            DataType::UInt32,
        );
        let arr2 = create_dictionary(
            (0u16..=65535).collect(),
            (65536u32..=131071).collect(),
            DataType::UInt16,
            DataType::UInt32,
        );
        assert!(matches!(
            merge_dictionaries(&DataType::UInt16, &DataType::UInt32, &[&arr1, &arr2]),
            Err(ArrowError::DictionaryKeyOverflowError),
        ));
    }
}
```
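The run-coalescing trick in `interleave` can also be shown standalone (illustration only, hypothetical function name): consecutive `(array, row)` pairs are collapsed into `(array, start, end)` half-open ranges, so each contiguous run costs a single copy instead of one copy per row.

```rust
/// Collapse consecutive (array, row) pairs into (array, start, end) runs,
/// mirroring the batching logic of `interleave` above.
fn coalesce_runs(indices: &[(usize, usize)]) -> Vec<(usize, usize, usize)> {
    let mut runs = Vec::new();
    let mut iter = indices.iter().copied();
    let Some((mut arr, mut start)) = iter.next() else {
        return runs;
    };
    let mut end = start + 1;
    for (a, row) in iter {
        if a == arr && row == end {
            // subsequent row of the same array: extend the current run
            end += 1;
        } else {
            // emit the finished run and start a new one
            runs.push((arr, start, end));
            arr = a;
            start = row;
            end = row + 1;
        }
    }
    runs.push((arr, start, end));
    runs
}

fn main() {
    let idx = [(0, 0), (0, 1), (0, 2), (1, 5), (0, 3)];
    // three runs instead of five single-row copies
    assert_eq!(coalesce_runs(&idx), vec![(0, 0, 3), (1, 5, 6), (0, 3, 4)]);
}
```

In the real code the runs are fed straight into `MutableArrayData::extend` rather than collected, but the grouping is the same.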

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we please leave some comments about what this function does and why it is needed? (aka explain the overflow backup case)