2323import org .apache .calcite .rel .RelCollations ;
2424import org .apache .calcite .rel .RelDistribution ;
2525import org .apache .calcite .rel .RelFieldCollation ;
26+ import org .apache .calcite .rel .RelHomogeneousShuttle ;
2627import org .apache .calcite .rel .RelNode ;
2728import org .apache .calcite .rel .core .Aggregate ;
2829import org .apache .calcite .rel .core .AggregateCall ;
6162import org .apache .calcite .rex .RexSubQuery ;
6263import org .apache .calcite .rex .RexUtil ;
6364import org .apache .calcite .rex .RexVisitor ;
65+ import org .apache .calcite .rex .RexVisitorImpl ;
6466import org .apache .calcite .sql .SqlExplainFormat ;
6567import org .apache .calcite .sql .SqlExplainLevel ;
6668import org .apache .calcite .sql .SqlKind ;
7880import org .apache .calcite .util .mapping .Mappings ;
7981
8082import com .google .common .collect .ImmutableList ;
83+ import com .google .common .collect .ImmutableSet ;
84+ import com .google .common .collect .Iterables ;
8185
8286import org .checkerframework .checker .nullness .qual .Nullable ;
8387
@@ -473,6 +477,77 @@ public TrimResult trimFields(
473477 return result (newCalc , mapping , calc );
474478 }
475479
480+ /**
481+ * Shuttle that finds all {@link TableScan}s inside a given {@link RelNode}.
482+ */
483+ private static class TableScanCollector extends RelHomogeneousShuttle {
484+ private ImmutableSet .Builder <List <String >> builder = ImmutableSet .builder ();
485+
486+ /** Qualified names. */
487+ Set <List <String >> tables () {
488+ return builder .build ();
489+ }
490+
491+ @ Override public RelNode visit (TableScan scan ) {
492+ builder .add (scan .getTable ().getQualifiedName ());
493+ return super .visit (scan );
494+ }
495+ }
496+
497+ /**
498+ * Shuttle that finds all {@link TableScan}`s inside a given {@link RexNode}.
499+ */
500+ private static class InputTablesVisitor extends RexVisitorImpl <Void > {
501+ private ImmutableSet .Builder <List <String >> builder = ImmutableSet .builder ();
502+
503+ protected InputTablesVisitor () {
504+ super (false );
505+ }
506+
507+ /** Qualified names. */
508+ Set <List <String >> tables () {
509+ return builder .build ();
510+ }
511+
512+ @ Override public Void visitSubQuery (RexSubQuery subQuery ) {
513+ if (subQuery .getKind () == SqlKind .SCALAR_QUERY ) {
514+ subQuery .rel .accept (new RelHomogeneousShuttle () {
515+ @ Override public RelNode visit (TableScan scan ) {
516+ builder .add (scan .getTable ().getQualifiedName ());
517+ return super .visit (scan );
518+ }
519+ });
520+ }
521+ return null ;
522+ }
523+ }
524+
525+ private boolean inputContainsSubQueryTables (Project project , RelNode input ) {
526+ InputTablesVisitor inputSubQueryTablesCollector = new InputTablesVisitor ();
527+
528+ RexUtil .apply (inputSubQueryTablesCollector , project .getProjects (), null );
529+
530+ Set <List <String >> subQueryTables = inputSubQueryTablesCollector .tables ();
531+
532+ assert subQueryTables .isEmpty () || subQueryTables .size () == 1
533+ : "unexpected different tables in subquery: " + subQueryTables ;
534+
535+ TableScanCollector inputTablesCollector = new TableScanCollector ();
536+ input .accept (inputTablesCollector );
537+
538+ Set <List <String >> inputTables = inputTablesCollector .tables ();
539+ // Check for input and subquery tables intersection.
540+ if (!subQueryTables .isEmpty ()) {
541+ for (List <String > t : inputTables ) {
542+ if (t .equals (Iterables .getOnlyElement (subQueryTables ))) {
543+ return true ;
544+ }
545+ }
546+ }
547+
548+ return false ;
549+ }
550+
476551 /**
477552 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
478553 * {@link org.apache.calcite.rel.logical.LogicalProject}.
@@ -488,18 +563,24 @@ public TrimResult trimFields(
488563 // Which fields are required from the input?
489564 final Set <RelDataTypeField > inputExtraFields =
490565 new LinkedHashSet <>(extraFields );
566+
567+ // Collect all the SubQueries in the projection list.
568+ List <RexSubQuery > subQueries = RexUtil .SubQueryCollector .collect (project );
569+ // Get all the correlationIds present in the SubQueries
570+ Set <CorrelationId > correlationIds = RelOptUtil .getVariablesUsed (subQueries );
571+ // Subquery lookup is required.
572+ boolean subQueryLookUp =
573+ !correlationIds .isEmpty () && inputContainsSubQueryTables (project , input );
574+
491575 RelOptUtil .InputFinder inputFinder =
492- new RelOptUtil .InputFinder (inputExtraFields );
576+ new RelOptUtil .SubQueryAwareInputFinder (inputExtraFields , subQueryLookUp );
577+
493578 for (Ord <RexNode > ord : Ord .zip (project .getProjects ())) {
494579 if (fieldsUsed .get (ord .i )) {
495580 ord .e .accept (inputFinder );
496581 }
497582 }
498583
499- // Collect all the SubQueries in the projection list.
500- List <RexSubQuery > subQueries = RexUtil .SubQueryCollector .collect (project );
501- // Get all the correlationIds present in the SubQueries
502- Set <CorrelationId > correlationIds = RelOptUtil .getVariablesUsed (subQueries );
503584 ImmutableBitSet requiredColumns = ImmutableBitSet .of ();
504585 if (!correlationIds .isEmpty ()) {
505586 assert correlationIds .size () == 1 ;
@@ -547,6 +628,14 @@ public TrimResult trimFields(
547628 if (fieldsUsed .get (ord .i )) {
548629 mapping .set (ord .i , newProjects .size ());
549630 RexNode newProjectExpr = ord .e .accept (shuttle );
631+ // Subquery need to be remapped
632+ if (newProjectExpr instanceof RexSubQuery
633+ && newProjectExpr .getKind () == SqlKind .SCALAR_QUERY
634+ && !correlationIds .isEmpty ()) {
635+ newProjectExpr =
636+ changeCorrelateReferences ((RexSubQuery ) newProjectExpr ,
637+ Iterables .getOnlyElement (correlationIds ), newInput .getRowType (), inputMapping );
638+ }
550639 newProjects .add (newProjectExpr );
551640 }
552641 }
@@ -561,6 +650,24 @@ public TrimResult trimFields(
561650 return result (newProject , mapping , project );
562651 }
563652
653+ private RexNode changeCorrelateReferences (
654+ RexSubQuery node ,
655+ CorrelationId corrId ,
656+ RelDataType rowType ,
657+ Mapping inputMapping ) {
658+ assert node .getKind () == SqlKind .SCALAR_QUERY : "Expected a SCALAR_QUERY, found "
659+ + node .getKind ();
660+ RelNode subQuery = node .rel ;
661+ RexBuilder rexBuilder = relBuilder .getRexBuilder ();
662+
663+ RexCorrelVariableMapShuttle rexVisitor =
664+ new RexCorrelVariableMapShuttle (corrId , rowType , inputMapping , rexBuilder );
665+ RelNode newSubQuery =
666+ subQuery .accept (new RexRewritingRelShuttle (rexVisitor ));
667+
668+ return RexSubQuery .scalar (newSubQuery );
669+ }
670+
564671 /** Creates a project with a dummy column, to protect the parts of the system
565672 * that cannot handle a relational expression with no columns.
566673 *
@@ -1437,9 +1544,12 @@ static class RexCorrelVariableMapShuttle extends RexShuttle {
14371544 (RexCorrelVariable ) fieldAccess .getReferenceExpr ();
14381545 if (referenceExpr .id .equals (correlationId )) {
14391546 int oldIndex = fieldAccess .getField ().getIndex ();
1547+ int newIndex = mapping .getTarget (oldIndex );
1548+ if (newIndex == oldIndex ) {
1549+ return super .visitFieldAccess (fieldAccess );
1550+ }
14401551 RexNode newCorrel =
14411552 rexBuilder .makeCorrel (newCorrelRowType , referenceExpr .id );
1442- int newIndex = mapping .getTarget (oldIndex );
14431553 return rexBuilder .makeFieldAccess (newCorrel , newIndex );
14441554 }
14451555 }
0 commit comments