Skip to content

Commit 00e09ac

Browse files
zstansuibianwanwank
authored andcommitted
[CALCITE-5638] Columns trimmer need to consider sub queries
1 parent 1d4b1fd commit 00e09ac

File tree

5 files changed

+210
-8
lines changed

5 files changed

+210
-8
lines changed

core/src/main/java/org/apache/calcite/plan/RelOptUtil.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4552,6 +4552,37 @@ private void acceptFields(final List<RelDataTypeField> fields) {
45524552
}
45534553
}
45544554

4555+
/** Extension of {@link RelOptUtil.InputFinder} with optional subquery lookup. */
4556+
public static class SubQueryAwareInputFinder extends RelOptUtil.InputFinder {
4557+
boolean visitSubQuery;
4558+
4559+
public SubQueryAwareInputFinder(@Nullable Set<RelDataTypeField> extraFields,
4560+
boolean visitSubQuery) {
4561+
super(extraFields, ImmutableBitSet.builder());
4562+
this.visitSubQuery = visitSubQuery;
4563+
}
4564+
4565+
@Override public Void visitSubQuery(RexSubQuery subQuery) {
4566+
if (visitSubQuery && subQuery.getKind() == SqlKind.SCALAR_QUERY) {
4567+
subQuery.rel.accept(new RelHomogeneousShuttle() {
4568+
@Override public RelNode visit(LogicalProject project) {
4569+
project.getProjects().forEach(r -> r.accept(SubQueryAwareInputFinder.this));
4570+
return super.visit(project);
4571+
}
4572+
4573+
@Override public RelNode visit(LogicalFilter filter) {
4574+
filter.getCondition().accept(SubQueryAwareInputFinder.this);
4575+
return super.visit(filter);
4576+
}
4577+
});
4578+
4579+
return null;
4580+
} else {
4581+
return super.visitSubQuery(subQuery);
4582+
}
4583+
}
4584+
}
4585+
45554586
/**
45564587
* Visitor which builds a bitmap of the inputs used by an expression.
45574588
*/

core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,24 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
910910
int nFieldsLeft = join.getLeft().getRowType().getFieldCount();
911911
int nFieldsRight = join.getRight().getRowType().getFieldCount();
912912

913+
// Correlation columns should also be considered.
914+
// For example:
915+
// LogicalJoin
916+
// left right
917+
// | |
918+
// LogicalProject.NONE.[0, 1] LogicalValues.NONE.[0]
919+
// RecordType(INTEGER DEPTNO, CHAR(11) DNAME) RecordType(INTEGER DEPTNO)
920+
//
921+
// and subquery: $SCALAR_QUERY with correlate
922+
// LogicalProject(DEPTNO=[$1])
923+
// LogicalFilter(condition=[=(CAST($0):CHAR(11) NOT NULL, $cor0.DNAME)])
924+
//
925+
// In such a case $cor0.DNAME need to be accounted as input form left side.
926+
final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
927+
for (CorrelationId id : variablesSet) {
928+
ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, e.rel);
929+
inputSet = ImmutableBitSet.union(ImmutableList.of(requiredColumns, inputSet));
930+
}
913931

914932
boolean inputIntersectsLeftSide = inputSet.intersects(ImmutableBitSet.range(0, nFieldsLeft));
915933
boolean inputIntersectsRightSide =
@@ -922,7 +940,6 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
922940
return;
923941
}
924942

925-
final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
926943
if (inputIntersectsLeftSide) {
927944
builder.push(join.getLeft());
928945

core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.calcite.rel.RelCollations;
2424
import org.apache.calcite.rel.RelDistribution;
2525
import org.apache.calcite.rel.RelFieldCollation;
26+
import org.apache.calcite.rel.RelHomogeneousShuttle;
2627
import org.apache.calcite.rel.RelNode;
2728
import org.apache.calcite.rel.core.Aggregate;
2829
import org.apache.calcite.rel.core.AggregateCall;
@@ -61,6 +62,7 @@
6162
import org.apache.calcite.rex.RexSubQuery;
6263
import org.apache.calcite.rex.RexUtil;
6364
import org.apache.calcite.rex.RexVisitor;
65+
import org.apache.calcite.rex.RexVisitorImpl;
6466
import org.apache.calcite.sql.SqlExplainFormat;
6567
import org.apache.calcite.sql.SqlExplainLevel;
6668
import org.apache.calcite.sql.SqlKind;
@@ -78,6 +80,8 @@
7880
import org.apache.calcite.util.mapping.Mappings;
7981

8082
import com.google.common.collect.ImmutableList;
83+
import com.google.common.collect.ImmutableSet;
84+
import com.google.common.collect.Iterables;
8185

8286
import org.checkerframework.checker.nullness.qual.Nullable;
8387

@@ -473,6 +477,77 @@ public TrimResult trimFields(
473477
return result(newCalc, mapping, calc);
474478
}
475479

480+
/**
481+
* Shuttle that finds all {@link TableScan}s inside a given {@link RelNode}.
482+
*/
483+
private static class TableScanCollector extends RelHomogeneousShuttle {
484+
private ImmutableSet.Builder<List<String>> builder = ImmutableSet.builder();
485+
486+
/** Qualified names. */
487+
Set<List<String>> tables() {
488+
return builder.build();
489+
}
490+
491+
@Override public RelNode visit(TableScan scan) {
492+
builder.add(scan.getTable().getQualifiedName());
493+
return super.visit(scan);
494+
}
495+
}
496+
497+
/**
498+
* Shuttle that finds all {@link TableScan}`s inside a given {@link RexNode}.
499+
*/
500+
private static class InputTablesVisitor extends RexVisitorImpl<Void> {
501+
private ImmutableSet.Builder<List<String>> builder = ImmutableSet.builder();
502+
503+
protected InputTablesVisitor() {
504+
super(false);
505+
}
506+
507+
/** Qualified names. */
508+
Set<List<String>> tables() {
509+
return builder.build();
510+
}
511+
512+
@Override public Void visitSubQuery(RexSubQuery subQuery) {
513+
if (subQuery.getKind() == SqlKind.SCALAR_QUERY) {
514+
subQuery.rel.accept(new RelHomogeneousShuttle() {
515+
@Override public RelNode visit(TableScan scan) {
516+
builder.add(scan.getTable().getQualifiedName());
517+
return super.visit(scan);
518+
}
519+
});
520+
}
521+
return null;
522+
}
523+
}
524+
525+
private boolean inputContainsSubQueryTables(Project project, RelNode input) {
526+
InputTablesVisitor inputSubQueryTablesCollector = new InputTablesVisitor();
527+
528+
RexUtil.apply(inputSubQueryTablesCollector, project.getProjects(), null);
529+
530+
Set<List<String>> subQueryTables = inputSubQueryTablesCollector.tables();
531+
532+
assert subQueryTables.isEmpty() || subQueryTables.size() == 1
533+
: "unexpected different tables in subquery: " + subQueryTables;
534+
535+
TableScanCollector inputTablesCollector = new TableScanCollector();
536+
input.accept(inputTablesCollector);
537+
538+
Set<List<String>> inputTables = inputTablesCollector.tables();
539+
// Check for input and subquery tables intersection.
540+
if (!subQueryTables.isEmpty()) {
541+
for (List<String> t : inputTables) {
542+
if (t.equals(Iterables.getOnlyElement(subQueryTables))) {
543+
return true;
544+
}
545+
}
546+
}
547+
548+
return false;
549+
}
550+
476551
/**
477552
* Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
478553
* {@link org.apache.calcite.rel.logical.LogicalProject}.
@@ -488,18 +563,24 @@ public TrimResult trimFields(
488563
// Which fields are required from the input?
489564
final Set<RelDataTypeField> inputExtraFields =
490565
new LinkedHashSet<>(extraFields);
566+
567+
// Collect all the SubQueries in the projection list.
568+
List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
569+
// Get all the correlationIds present in the SubQueries
570+
Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries);
571+
// Subquery lookup is required.
572+
boolean subQueryLookUp =
573+
!correlationIds.isEmpty() && inputContainsSubQueryTables(project, input);
574+
491575
RelOptUtil.InputFinder inputFinder =
492-
new RelOptUtil.InputFinder(inputExtraFields);
576+
new RelOptUtil.SubQueryAwareInputFinder(inputExtraFields, subQueryLookUp);
577+
493578
for (Ord<RexNode> ord : Ord.zip(project.getProjects())) {
494579
if (fieldsUsed.get(ord.i)) {
495580
ord.e.accept(inputFinder);
496581
}
497582
}
498583

499-
// Collect all the SubQueries in the projection list.
500-
List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
501-
// Get all the correlationIds present in the SubQueries
502-
Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries);
503584
ImmutableBitSet requiredColumns = ImmutableBitSet.of();
504585
if (!correlationIds.isEmpty()) {
505586
assert correlationIds.size() == 1;
@@ -547,6 +628,14 @@ public TrimResult trimFields(
547628
if (fieldsUsed.get(ord.i)) {
548629
mapping.set(ord.i, newProjects.size());
549630
RexNode newProjectExpr = ord.e.accept(shuttle);
631+
// Subquery need to be remapped
632+
if (newProjectExpr instanceof RexSubQuery
633+
&& newProjectExpr.getKind() == SqlKind.SCALAR_QUERY
634+
&& !correlationIds.isEmpty()) {
635+
newProjectExpr =
636+
changeCorrelateReferences((RexSubQuery) newProjectExpr,
637+
Iterables.getOnlyElement(correlationIds), newInput.getRowType(), inputMapping);
638+
}
550639
newProjects.add(newProjectExpr);
551640
}
552641
}
@@ -561,6 +650,24 @@ public TrimResult trimFields(
561650
return result(newProject, mapping, project);
562651
}
563652

653+
private RexNode changeCorrelateReferences(
654+
RexSubQuery node,
655+
CorrelationId corrId,
656+
RelDataType rowType,
657+
Mapping inputMapping) {
658+
assert node.getKind() == SqlKind.SCALAR_QUERY : "Expected a SCALAR_QUERY, found "
659+
+ node.getKind();
660+
RelNode subQuery = node.rel;
661+
RexBuilder rexBuilder = relBuilder.getRexBuilder();
662+
663+
RexCorrelVariableMapShuttle rexVisitor =
664+
new RexCorrelVariableMapShuttle(corrId, rowType, inputMapping, rexBuilder);
665+
RelNode newSubQuery =
666+
subQuery.accept(new RexRewritingRelShuttle(rexVisitor));
667+
668+
return RexSubQuery.scalar(newSubQuery);
669+
}
670+
564671
/** Creates a project with a dummy column, to protect the parts of the system
565672
* that cannot handle a relational expression with no columns.
566673
*
@@ -1437,9 +1544,12 @@ static class RexCorrelVariableMapShuttle extends RexShuttle {
14371544
(RexCorrelVariable) fieldAccess.getReferenceExpr();
14381545
if (referenceExpr.id.equals(correlationId)) {
14391546
int oldIndex = fieldAccess.getField().getIndex();
1547+
int newIndex = mapping.getTarget(oldIndex);
1548+
if (newIndex == oldIndex) {
1549+
return super.visitFieldAccess(fieldAccess);
1550+
}
14401551
RexNode newCorrel =
14411552
rexBuilder.makeCorrel(newCorrelRowType, referenceExpr.id);
1442-
int newIndex = mapping.getTarget(oldIndex);
14431553
return rexBuilder.makeFieldAccess(newCorrel, newIndex);
14441554
}
14451555
}

core/src/test/java/org/apache/calcite/test/enumerable/EnumerableCorrelateTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,24 @@ class EnumerableCorrelateTest {
168168
"empid=200");
169169
}
170170

171+
/** Test case for
172+
* <a href="https://issues.apache.org/jira/browse/CALCITE-5638">[CALCITE-5638]
173+
* Columns trimmer need to consider sub queries</a>.
174+
*/
175+
@Test void complexNestedCorrelatedSubquery() {
176+
String sql = "SELECT empid, deptno, (SELECT count(*) FROM emps AS x "
177+
+ "WHERE x.salary>emps.salary and x.deptno<emps.deptno) FROM emps "
178+
+ "WHERE empid<salary ORDER BY 1,2,3";
179+
180+
tester(false, new HrSchema())
181+
.query(sql)
182+
.returnsOrdered(
183+
"empid=100; deptno=10; EXPR$2=0",
184+
"empid=110; deptno=10; EXPR$2=0",
185+
"empid=150; deptno=10; EXPR$2=0",
186+
"empid=200; deptno=20; EXPR$2=2");
187+
}
188+
171189
/** Test case for
172190
* <a href="https://issues.apache.org/jira/browse/CALCITE-2920">[CALCITE-2920]
173191
* RelBuilder: new method to create an anti-join</a>. */

core/src/test/resources/sql/sub-query.iq

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,32 @@ select * from dept where deptno not in (select deptno from emp);
6363
(0 rows)
6464

6565
!ok
66+
67+
# [CALCITE-5638] Assertion Failure during planning correlated query
68+
SELECT "hr"."emps"."empid", "hr"."emps"."deptno",
69+
(SELECT count(*) FROM "hr"."emps" AS x WHERE x."salary">"hr"."emps"."salary" AND x."deptno"<"hr"."emps"."deptno")
70+
FROM "hr"."emps"
71+
WHERE "hr"."emps"."empid"<"hr"."emps"."salary"
72+
ORDER BY 1,2,3;
73+
empid | deptno | EXPR$2
74+
-------+--------+--------
75+
100 | 10 | 0
76+
110 | 10 | 0
77+
150 | 10 | 0
78+
200 | 20 | 2
79+
(4 rows)
80+
81+
!ok
82+
83+
# [CALCITE-5638] Assertion Failure during planning correlated query
84+
SELECT t1.deptno FROM dept AS t0 JOIN emp AS t1 ON
85+
(t1.deptno = (SELECT inner_t1.deptno FROM emp AS inner_t1 WHERE inner_t1.ENAME = t0.DNAME));
86+
DEPTNO
87+
--------
88+
(0 rows)
89+
90+
!ok
91+
6692
select deptno, deptno in (select deptno from emp) from dept;
6793
DEPTNO | EXPR$1
6894
--------+--------
@@ -873,7 +899,7 @@ EnumerableCalc(expr#0..2=[{inputs}], proj#0..1=[{exprs}])
873899
EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
874900
EnumerableValues(tuples=[[{ 1, 2 }]])
875901
EnumerableAggregate(group=[{0}])
876-
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.EXPR$0], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12])
902+
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.A], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12])
877903
EnumerableTableScan(table=[[scott, EMP]])
878904
!plan
879905

0 commit comments

Comments
 (0)