-
Notifications
You must be signed in to change notification settings - Fork 1.1k
perf: improve field indexing in JSON StructArrayDecoder (1.7x speed up) #9086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
726d1bf
5109576
e593630
06ded8b
df9e710
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,13 +21,18 @@ use arrow_array::builder::BooleanBufferBuilder; | |
| use arrow_buffer::buffer::NullBuffer; | ||
| use arrow_data::{ArrayData, ArrayDataBuilder}; | ||
| use arrow_schema::{ArrowError, DataType, Fields}; | ||
| use std::collections::HashMap; | ||
|
|
||
| pub struct StructArrayDecoder { | ||
| data_type: DataType, | ||
| decoders: Vec<Box<dyn ArrayDecoder>>, | ||
| strict_mode: bool, | ||
| is_nullable: bool, | ||
| struct_mode: StructMode, | ||
| field_name_to_index: Option<HashMap<String, usize>>, | ||
| /// Reusable buffer of tape positions indexed as `[field_idx * row_count + row_idx]`. | ||
| /// A value of 0 indicates the field is absent for that row. | ||
| field_tape_positions: Vec<u32>, | ||
| } | ||
|
|
||
| impl StructArrayDecoder { | ||
|
|
@@ -38,131 +43,162 @@ impl StructArrayDecoder { | |
| is_nullable: bool, | ||
| struct_mode: StructMode, | ||
| ) -> Result<Self, ArrowError> { | ||
| let decoders = struct_fields(&data_type) | ||
| .iter() | ||
| .map(|f| { | ||
| // If this struct nullable, need to permit nullability in child array | ||
| // StructArrayDecoder::decode verifies that if the child is not nullable | ||
| // it doesn't contain any nulls not masked by its parent | ||
| let nullable = f.is_nullable() || is_nullable; | ||
| make_decoder( | ||
| f.data_type().clone(), | ||
| coerce_primitive, | ||
| strict_mode, | ||
| nullable, | ||
| struct_mode, | ||
| ) | ||
| }) | ||
| .collect::<Result<Vec<_>, ArrowError>>()?; | ||
| let (decoders, field_name_to_index) = { | ||
| let fields = struct_fields(&data_type); | ||
| let decoders = fields | ||
| .iter() | ||
| .map(|f| { | ||
| // If this struct nullable, need to permit nullability in child array | ||
| // StructArrayDecoder::decode verifies that if the child is not nullable | ||
| // it doesn't contain any nulls not masked by its parent | ||
| let nullable = f.is_nullable() || is_nullable; | ||
| make_decoder( | ||
| f.data_type().clone(), | ||
| coerce_primitive, | ||
| strict_mode, | ||
| nullable, | ||
| struct_mode, | ||
| ) | ||
| }) | ||
| .collect::<Result<Vec<_>, ArrowError>>()?; | ||
| let field_name_to_index = if struct_mode == StructMode::ObjectOnly { | ||
| build_field_index(fields) | ||
| } else { | ||
| None | ||
| }; | ||
| (decoders, field_name_to_index) | ||
| }; | ||
|
|
||
| Ok(Self { | ||
| data_type, | ||
| decoders, | ||
| strict_mode, | ||
| is_nullable, | ||
| struct_mode, | ||
| field_name_to_index, | ||
| field_tape_positions: Vec::new(), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| impl ArrayDecoder for StructArrayDecoder { | ||
| fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> { | ||
| let fields = struct_fields(&self.data_type); | ||
| let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect(); | ||
|
|
||
| let row_count = pos.len(); | ||
| let field_count = fields.len(); | ||
| let total_len = field_count.checked_mul(row_count).ok_or_else(|| { | ||
| ArrowError::JsonError(format!( | ||
| "StructArrayDecoder child position buffer size overflow for rows={row_count} fields={field_count}" | ||
| )) | ||
| })?; | ||
| self.field_tape_positions.clear(); | ||
| self.field_tape_positions.resize(total_len, 0); | ||
| let mut nulls = self | ||
| .is_nullable | ||
| .then(|| BooleanBufferBuilder::new(pos.len())); | ||
|
|
||
| // We avoid having the match on self.struct_mode inside the hot loop for performance | ||
| // TODO: Investigate how to extract duplicated logic. | ||
| match self.struct_mode { | ||
| StructMode::ObjectOnly => { | ||
| for (row, p) in pos.iter().enumerate() { | ||
| let end_idx = match (tape.get(*p), nulls.as_mut()) { | ||
| (TapeElement::StartObject(end_idx), None) => end_idx, | ||
| (TapeElement::StartObject(end_idx), Some(nulls)) => { | ||
| nulls.append(true); | ||
| end_idx | ||
| } | ||
| (TapeElement::Null, Some(nulls)) => { | ||
| nulls.append(false); | ||
| continue; | ||
| } | ||
| (_, _) => return Err(tape.error(*p, "{")), | ||
| }; | ||
|
|
||
| let mut cur_idx = *p + 1; | ||
| while cur_idx < end_idx { | ||
| // Read field name | ||
| let field_name = match tape.get(cur_idx) { | ||
| TapeElement::String(s) => tape.get_string(s), | ||
| _ => return Err(tape.error(cur_idx, "field name")), | ||
| { | ||
| let child_pos = self.field_tape_positions.as_mut_slice(); | ||
| // We avoid having the match on self.struct_mode inside the hot loop for performance | ||
| // TODO: Investigate how to extract duplicated logic. | ||
| match self.struct_mode { | ||
| StructMode::ObjectOnly => { | ||
| for (row, p) in pos.iter().enumerate() { | ||
| let end_idx = match (tape.get(*p), nulls.as_mut()) { | ||
| (TapeElement::StartObject(end_idx), None) => end_idx, | ||
| (TapeElement::StartObject(end_idx), Some(nulls)) => { | ||
| nulls.append(true); | ||
| end_idx | ||
| } | ||
| (TapeElement::Null, Some(nulls)) => { | ||
| nulls.append(false); | ||
| continue; | ||
| } | ||
| (_, _) => return Err(tape.error(*p, "{")), | ||
| }; | ||
|
|
||
| // Update child pos if match found | ||
| match fields.iter().position(|x| x.name() == field_name) { | ||
| Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1, | ||
| None => { | ||
| if self.strict_mode { | ||
| return Err(ArrowError::JsonError(format!( | ||
| "column '{field_name}' missing from schema", | ||
| ))); | ||
| let mut cur_idx = *p + 1; | ||
| while cur_idx < end_idx { | ||
| // Read field name | ||
| let field_name = match tape.get(cur_idx) { | ||
| TapeElement::String(s) => tape.get_string(s), | ||
| _ => return Err(tape.error(cur_idx, "field name")), | ||
| }; | ||
|
|
||
| // Update child pos if match found | ||
| let field_idx = match &self.field_name_to_index { | ||
| Some(map) => map.get(field_name).copied(), | ||
| None => fields.iter().position(|x| x.name() == field_name), | ||
| }; | ||
| match field_idx { | ||
| Some(field_idx) => { | ||
| child_pos[field_idx * row_count + row] = cur_idx + 1; | ||
| } | ||
| None => { | ||
| if self.strict_mode { | ||
| return Err(ArrowError::JsonError(format!( | ||
| "column '{field_name}' missing from schema", | ||
| ))); | ||
| } | ||
| } | ||
| } | ||
| // Advance to next field | ||
| cur_idx = tape.next(cur_idx + 1, "field value")?; | ||
| } | ||
| // Advance to next field | ||
| cur_idx = tape.next(cur_idx + 1, "field value")?; | ||
| } | ||
| } | ||
| } | ||
| StructMode::ListOnly => { | ||
| for (row, p) in pos.iter().enumerate() { | ||
| let end_idx = match (tape.get(*p), nulls.as_mut()) { | ||
| (TapeElement::StartList(end_idx), None) => end_idx, | ||
| (TapeElement::StartList(end_idx), Some(nulls)) => { | ||
| nulls.append(true); | ||
| end_idx | ||
| } | ||
| (TapeElement::Null, Some(nulls)) => { | ||
| nulls.append(false); | ||
| continue; | ||
| } | ||
| (_, _) => return Err(tape.error(*p, "[")), | ||
| }; | ||
| StructMode::ListOnly => { | ||
| for (row, p) in pos.iter().enumerate() { | ||
| let end_idx = match (tape.get(*p), nulls.as_mut()) { | ||
| (TapeElement::StartList(end_idx), None) => end_idx, | ||
| (TapeElement::StartList(end_idx), Some(nulls)) => { | ||
| nulls.append(true); | ||
| end_idx | ||
| } | ||
| (TapeElement::Null, Some(nulls)) => { | ||
| nulls.append(false); | ||
| continue; | ||
| } | ||
| (_, _) => return Err(tape.error(*p, "[")), | ||
| }; | ||
|
|
||
| let mut cur_idx = *p + 1; | ||
| let mut entry_idx = 0; | ||
| while cur_idx < end_idx { | ||
| if entry_idx >= fields.len() { | ||
| let mut cur_idx = *p + 1; | ||
| let mut entry_idx = 0; | ||
| while cur_idx < end_idx { | ||
| if entry_idx >= fields.len() { | ||
| return Err(ArrowError::JsonError(format!( | ||
| "found extra columns for {} fields", | ||
| fields.len() | ||
| ))); | ||
| } | ||
| child_pos[entry_idx * row_count + row] = cur_idx; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This can also be part of the dedicated struct |
||
| entry_idx += 1; | ||
| // Advance to next field | ||
| cur_idx = tape.next(cur_idx, "field value")?; | ||
| } | ||
| if entry_idx != fields.len() { | ||
| return Err(ArrowError::JsonError(format!( | ||
| "found extra columns for {} fields", | ||
| "found {} columns for {} fields", | ||
| entry_idx, | ||
| fields.len() | ||
| ))); | ||
| } | ||
| child_pos[entry_idx][row] = cur_idx; | ||
| entry_idx += 1; | ||
| // Advance to next field | ||
| cur_idx = tape.next(cur_idx, "field value")?; | ||
| } | ||
| if entry_idx != fields.len() { | ||
| return Err(ArrowError::JsonError(format!( | ||
| "found {} columns for {} fields", | ||
| entry_idx, | ||
| fields.len() | ||
| ))); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| let child_pos = self.field_tape_positions.as_slice(); | ||
| let child_data = self | ||
| .decoders | ||
| .iter_mut() | ||
| .zip(child_pos) | ||
| .enumerate() | ||
| .zip(fields) | ||
| .map(|((d, pos), f)| { | ||
| d.decode(tape, &pos).map_err(|e| match e { | ||
| .map(|((field_idx, d), f)| { | ||
| let start = field_idx * row_count; | ||
| let end = start + row_count; | ||
| let pos = &child_pos[start..end]; | ||
|
Comment on lines
+198
to
+200
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible to extract the |
||
| d.decode(tape, pos).map_err(|e| match e { | ||
| ArrowError::JsonError(s) => { | ||
| ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name())) | ||
| } | ||
|
|
@@ -205,3 +241,19 @@ fn struct_fields(data_type: &DataType) -> &Fields { | |
| _ => unreachable!(), | ||
| } | ||
| } | ||
|
|
||
| fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. qq: Do lifetimes coincide so that we could return
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, the lifetimes do coincide. We can use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe it would be a good follow-on PR |
||
| const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you please add a comment explaining that const — why that number? |
||
| if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD { | ||
| return None; | ||
| } | ||
|
|
||
| let mut map = HashMap::with_capacity(fields.len()); | ||
| for (idx, field) in fields.iter().enumerate() { | ||
| let name = field.name(); | ||
| if !map.contains_key(name) { | ||
| map.insert(name.to_string(), idx); | ||
| } | ||
| } | ||
| Some(map) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
👍 this is a nice way to avoid allocations