Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 69 additions & 17 deletions src/rewrite/normal_form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ impl SpjNormalForm {
/// Stores information on filters from a Select-Project-Join plan.
#[derive(Debug, Clone)]
struct Predicate {
/// Full table schema, including all possible columns.
schema: DFSchema,
/// List of column equivalence classes.
eq_classes: Vec<ColumnEquivalenceClass>,
Expand All @@ -350,10 +351,14 @@ impl Predicate {
let mut schema = DFSchema::empty();
plan.apply(|plan| {
if let LogicalPlan::TableScan(scan) = plan {
let new_schema = DFSchema::try_from_qualified_schema(
scan.table_name.clone(),
scan.source.schema().as_ref(),
)?;
schema = if schema.fields().is_empty() {
(*scan.projected_schema).clone()
new_schema
} else {
schema.join(&scan.projected_schema)?
schema.join(&new_schema)?
}
}

Expand All @@ -371,7 +376,13 @@ impl Predicate {
// Collect all referenced columns
plan.apply(|plan| {
if let LogicalPlan::TableScan(scan) = plan {
for (i, (table_ref, field)) in scan.projected_schema.iter().enumerate() {
for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema(
scan.table_name.clone(),
scan.source.schema().as_ref(),
)?
.iter()
.enumerate()
{
let column = Column::new(table_ref.cloned(), field.name());
let data_type = field.data_type();
new.eq_classes
Expand Down Expand Up @@ -948,17 +959,47 @@ fn get_table_scan_columns(scan: &TableScan) -> Result<Vec<Column>> {
#[cfg(test)]
mod test {
use arrow::compute::concat_batches;
use datafusion::{datasource::provider_as_source, prelude::SessionContext};
use datafusion::{
datasource::provider_as_source,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::{DataFusionError, Result};
use datafusion_sql::TableReference;
use tempfile::tempdir;

use super::SpjNormalForm;

async fn setup() -> Result<SessionContext> {
let ctx = SessionContext::new();
let ctx = SessionContext::new_with_config(
SessionConfig::new()
.set_bool("datafusion.execution.parquet.pushdown_filters", true)
.set_bool("datafusion.explain.logical_plan_only", true),
);

let t1_path = tempdir()?;

// Create external table to exercise parquet filter pushdown.
// This will put the filters directly inside the `TableScan` node.
// This is important because `TableScan` can have filters on
// columns not in its own output.
ctx.sql(&format!(
"
CREATE EXTERNAL TABLE t1 (
column1 VARCHAR,
column2 BIGINT,
column3 CHAR
)
STORED AS PARQUET
LOCATION '{}'",
t1_path.path().to_string_lossy()
))
.await
.map_err(|e| e.context("setup `t1` table"))?
.collect()
.await?;

ctx.sql(
"CREATE TABLE t1 AS VALUES
"INSERT INTO t1 VALUES
('2021', 3, 'A'),
('2022', 4, 'B'),
('2023', 5, 'C')",
Expand All @@ -980,8 +1021,7 @@ mod test {
o_orderdate DATE,
p_name VARCHAR,
p_partkey INT
)
",
)",
)
.await
.map_err(|e| e.context("parse `example` table ddl"))?
Expand Down Expand Up @@ -1014,6 +1054,15 @@ mod test {
let query_plan = context.sql(case.query).await?.into_optimized_plan()?;
let query_normal_form = SpjNormalForm::new(&query_plan)?;

for plan in [&base_plan, &query_plan] {
context
.execute_logical_plan(plan.clone())
.await?
.explain(false, false)?
.show()
.await?;
}

let table_ref = TableReference::bare("mv");
let rewritten = query_normal_form
.rewrite_from(
Expand All @@ -1025,16 +1074,14 @@ mod test {
"expected rewrite to succeed".to_string(),
))?;

assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());
context
.execute_logical_plan(rewritten.clone())
.await?
.explain(false, false)?
.show()
.await?;

for plan in [&base_plan, &query_plan, &rewritten] {
context
.execute_logical_plan(plan.clone())
.await?
.explain(false, false)?
.show()
.await?;
}
assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());

let expected = concat_batches(
&query_plan.schema().as_ref().clone().into(),
Expand Down Expand Up @@ -1133,6 +1180,11 @@ mod test {
l_quantity*l_extendedprice > 100
",
},
TestCase {
name: "naked table scan with pushed down filters",
base: "SELECT column1 FROM t1 WHERE column2 <= 3",
query: "SELECT FROM t1 WHERE column2 <= 3",
},
];

for case in cases {
Expand Down
Loading