Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 138 additions & 86 deletions arrow-json/src/reader/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ use arrow_array::builder::BooleanBufferBuilder;
use arrow_buffer::buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Fields};
use std::collections::HashMap;

/// Decoder that assembles JSON rows into struct-shaped `ArrayData`,
/// dispatching each field's values to a per-field child decoder.
pub struct StructArrayDecoder {
    /// The struct `DataType` being decoded; source of the child `Fields`.
    data_type: DataType,
    /// One child decoder per struct field, in field order.
    decoders: Vec<Box<dyn ArrayDecoder>>,
    /// When set, a JSON key with no matching schema field is an error
    /// rather than being silently ignored.
    strict_mode: bool,
    /// Whether this struct column itself may contain nulls.
    is_nullable: bool,
    /// Whether rows are encoded as JSON objects (`ObjectOnly`) or as
    /// JSON lists of values in field order (`ListOnly`).
    struct_mode: StructMode,
    /// Optional field-name -> field-index lookup used on the object-decoding
    /// path; `None` when a linear scan over the fields is used instead.
    field_name_to_index: Option<HashMap<String, usize>>,
    /// Reusable buffer of tape positions indexed as `[field_idx * row_count + row_idx]`.
    /// A value of 0 indicates the field is absent for that row.
    field_tape_positions: Vec<u32>,
}

impl StructArrayDecoder {
Expand All @@ -38,131 +43,162 @@ impl StructArrayDecoder {
is_nullable: bool,
struct_mode: StructMode,
) -> Result<Self, ArrowError> {
let decoders = struct_fields(&data_type)
.iter()
.map(|f| {
// If this struct nullable, need to permit nullability in child array
// StructArrayDecoder::decode verifies that if the child is not nullable
// it doesn't contain any nulls not masked by its parent
let nullable = f.is_nullable() || is_nullable;
make_decoder(
f.data_type().clone(),
coerce_primitive,
strict_mode,
nullable,
struct_mode,
)
})
.collect::<Result<Vec<_>, ArrowError>>()?;
let (decoders, field_name_to_index) = {
let fields = struct_fields(&data_type);
let decoders = fields
.iter()
.map(|f| {
// If this struct nullable, need to permit nullability in child array
// StructArrayDecoder::decode verifies that if the child is not nullable
// it doesn't contain any nulls not masked by its parent
let nullable = f.is_nullable() || is_nullable;
make_decoder(
f.data_type().clone(),
coerce_primitive,
strict_mode,
nullable,
struct_mode,
)
})
.collect::<Result<Vec<_>, ArrowError>>()?;
let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
build_field_index(fields)
} else {
None
};
(decoders, field_name_to_index)
};

Ok(Self {
data_type,
decoders,
strict_mode,
is_nullable,
struct_mode,
field_name_to_index,
field_tape_positions: Vec::new(),
})
}
}

impl ArrayDecoder for StructArrayDecoder {
fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
let fields = struct_fields(&self.data_type);
let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect();

let row_count = pos.len();
let field_count = fields.len();
let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
ArrowError::JsonError(format!(
"StructArrayDecoder child position buffer size overflow for rows={row_count} fields={field_count}"
))
})?;
self.field_tape_positions.clear();
self.field_tape_positions.resize(total_len, 0);
let mut nulls = self
.is_nullable
.then(|| BooleanBufferBuilder::new(pos.len()));

// We avoid having the match on self.struct_mode inside the hot loop for performance
// TODO: Investigate how to extract duplicated logic.
match self.struct_mode {
StructMode::ObjectOnly => {
for (row, p) in pos.iter().enumerate() {
let end_idx = match (tape.get(*p), nulls.as_mut()) {
(TapeElement::StartObject(end_idx), None) => end_idx,
(TapeElement::StartObject(end_idx), Some(nulls)) => {
nulls.append(true);
end_idx
}
(TapeElement::Null, Some(nulls)) => {
nulls.append(false);
continue;
}
(_, _) => return Err(tape.error(*p, "{")),
};

let mut cur_idx = *p + 1;
while cur_idx < end_idx {
// Read field name
let field_name = match tape.get(cur_idx) {
TapeElement::String(s) => tape.get_string(s),
_ => return Err(tape.error(cur_idx, "field name")),
{
let child_pos = self.field_tape_positions.as_mut_slice();
// We avoid having the match on self.struct_mode inside the hot loop for performance
// TODO: Investigate how to extract duplicated logic.
match self.struct_mode {
StructMode::ObjectOnly => {
for (row, p) in pos.iter().enumerate() {
let end_idx = match (tape.get(*p), nulls.as_mut()) {
(TapeElement::StartObject(end_idx), None) => end_idx,
(TapeElement::StartObject(end_idx), Some(nulls)) => {
nulls.append(true);
end_idx
}
(TapeElement::Null, Some(nulls)) => {
nulls.append(false);
continue;
}
(_, _) => return Err(tape.error(*p, "{")),
};

// Update child pos if match found
match fields.iter().position(|x| x.name() == field_name) {
Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
None => {
if self.strict_mode {
return Err(ArrowError::JsonError(format!(
"column '{field_name}' missing from schema",
)));
let mut cur_idx = *p + 1;
while cur_idx < end_idx {
// Read field name
let field_name = match tape.get(cur_idx) {
TapeElement::String(s) => tape.get_string(s),
_ => return Err(tape.error(cur_idx, "field name")),
};

// Update child pos if match found
let field_idx = match &self.field_name_to_index {
Some(map) => map.get(field_name).copied(),
None => fields.iter().position(|x| x.name() == field_name),
};
match field_idx {
Some(field_idx) => {
child_pos[field_idx * row_count + row] = cur_idx + 1;
}
None => {
if self.strict_mode {
return Err(ArrowError::JsonError(format!(
"column '{field_name}' missing from schema",
)));
}
}
}
// Advance to next field
cur_idx = tape.next(cur_idx + 1, "field value")?;
}
// Advance to next field
cur_idx = tape.next(cur_idx + 1, "field value")?;
}
}
}
StructMode::ListOnly => {
for (row, p) in pos.iter().enumerate() {
let end_idx = match (tape.get(*p), nulls.as_mut()) {
(TapeElement::StartList(end_idx), None) => end_idx,
(TapeElement::StartList(end_idx), Some(nulls)) => {
nulls.append(true);
end_idx
}
(TapeElement::Null, Some(nulls)) => {
nulls.append(false);
continue;
}
(_, _) => return Err(tape.error(*p, "[")),
};
StructMode::ListOnly => {
for (row, p) in pos.iter().enumerate() {
let end_idx = match (tape.get(*p), nulls.as_mut()) {
(TapeElement::StartList(end_idx), None) => end_idx,
(TapeElement::StartList(end_idx), Some(nulls)) => {
nulls.append(true);
end_idx
}
(TapeElement::Null, Some(nulls)) => {
nulls.append(false);
continue;
}
(_, _) => return Err(tape.error(*p, "[")),
};

let mut cur_idx = *p + 1;
let mut entry_idx = 0;
while cur_idx < end_idx {
if entry_idx >= fields.len() {
let mut cur_idx = *p + 1;
let mut entry_idx = 0;
while cur_idx < end_idx {
if entry_idx >= fields.len() {
return Err(ArrowError::JsonError(format!(
"found extra columns for {} fields",
fields.len()
)));
}
child_pos[entry_idx * row_count + row] = cur_idx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 this is a nice way to avoid allocations

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also be part of the dedicated struct

entry_idx += 1;
// Advance to next field
cur_idx = tape.next(cur_idx, "field value")?;
}
if entry_idx != fields.len() {
return Err(ArrowError::JsonError(format!(
"found extra columns for {} fields",
"found {} columns for {} fields",
entry_idx,
fields.len()
)));
}
child_pos[entry_idx][row] = cur_idx;
entry_idx += 1;
// Advance to next field
cur_idx = tape.next(cur_idx, "field value")?;
}
if entry_idx != fields.len() {
return Err(ArrowError::JsonError(format!(
"found {} columns for {} fields",
entry_idx,
fields.len()
)));
}
}
}
}

let child_pos = self.field_tape_positions.as_slice();
let child_data = self
.decoders
.iter_mut()
.zip(child_pos)
.enumerate()
.zip(fields)
.map(|((d, pos), f)| {
d.decode(tape, &pos).map_err(|e| match e {
.map(|((field_idx, d), f)| {
let start = field_idx * row_count;
let end = start + row_count;
let pos = &child_pos[start..end];
Comment on lines +198 to +200
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to extract the field_tape_positions into another private struct that expose the api and hide the implementation detail

d.decode(tape, pos).map_err(|e| match e {
ArrowError::JsonError(s) => {
ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
}
Expand Down Expand Up @@ -205,3 +241,19 @@ fn struct_fields(data_type: &DataType) -> &Fields {
_ => unreachable!(),
}
}

fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: Do lifetimes coincide so that we could return Option<HashMap<&str, usize>> instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the lifetimes do coincide. we can use HashMap<&'a str, usize> by taking fields: &'a Fields as a parameter, which avoids the self-referential struct problem. However, this would require threading the lifetime parameter <'a> through the entire decoder system across many files. Since the lookup performance is identical, I don’t think it’s worth the added complexity.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it would be a good follow on PR

const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment about that const please, why that number

if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
return None;
}

let mut map = HashMap::with_capacity(fields.len());
for (idx, field) in fields.iter().enumerate() {
let name = field.name();
if !map.contains_key(name) {
map.insert(name.to_string(), idx);
}
}
Some(map)
}
Loading