33import static io .substrait .isthmus .SqlConverterBase .EXTENSION_COLLECTION ;
44
55import com .google .common .collect .ImmutableList ;
6+ import com .google .common .collect .Range ;
7+ import com .google .common .collect .RangeMap ;
8+ import com .google .common .collect .TreeRangeMap ;
69import io .substrait .expression .Expression ;
710import io .substrait .expression .Expression .SortDirection ;
811import io .substrait .expression .FunctionArg ;
3437import java .util .ArrayList ;
3538import java .util .Collection ;
3639import java .util .Collections ;
40+ import java .util .HashSet ;
3741import java .util .List ;
3842import java .util .Optional ;
3943import java .util .OptionalLong ;
44+ import java .util .Stack ;
4045import java .util .stream .Collectors ;
4146import java .util .stream .IntStream ;
4247import java .util .stream .Stream ;
4954import org .apache .calcite .rel .RelFieldCollation ;
5055import org .apache .calcite .rel .RelNode ;
5156import org .apache .calcite .rel .core .AggregateCall ;
57+ import org .apache .calcite .rel .core .CorrelationId ;
5258import org .apache .calcite .rel .core .JoinRelType ;
5359import org .apache .calcite .rel .core .TableModify ;
5460import org .apache .calcite .rel .logical .LogicalTableModify ;
@@ -156,8 +162,11 @@ public static RelNode convert(
156162 @ Override
157163 public RelNode visit (Filter filter , Context context ) throws RuntimeException {
158164 RelNode input = filter .getInput ().accept (this , context );
165+ context .pushOuterRowType (input .getRowType ());
159166 RexNode filterCondition = filter .getCondition ().accept (expressionRexConverter , context );
160- RelNode node = relBuilder .push (input ).filter (filterCondition ).build ();
167+ RelNode node =
168+ relBuilder .push (input ).filter (context .popCorrelationIds (), filterCondition ).build ();
169+ context .popOuterRowType ();
161170 return applyRemap (node , filter .getRemap ());
162171 }
163172
@@ -183,6 +192,8 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti
183192 @ Override
184193 public RelNode visit (Project project , Context context ) throws RuntimeException {
185194 RelNode child = project .getInput ().accept (this , context );
195+ context .pushOuterRowType (child .getRowType ());
196+
186197 Stream <RexNode > directOutputs =
187198 IntStream .range (0 , child .getRowType ().getFieldCount ())
188199 .mapToObj (fieldIndex -> rexBuilder .makeInputRef (child , fieldIndex ));
@@ -193,7 +204,12 @@ public RelNode visit(Project project, Context context) throws RuntimeException {
193204 List <RexNode > rexExprs =
194205 Stream .concat (directOutputs , exprs ).collect (java .util .stream .Collectors .toList ());
195206
196- RelNode node = relBuilder .push (child ).project (rexExprs ).build ();
207+ RelNode node =
208+ relBuilder
209+ .push (child )
210+ .project (rexExprs , List .of (), false , context .popCorrelationIds ())
211+ .build ();
212+ context .popOuterRowType ();
197213 return applyRemap (node , project .getRemap ());
198214 }
199215
@@ -211,12 +227,19 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException {
211227 public RelNode visit (Join join , Context context ) throws RuntimeException {
212228 RelNode left = join .getLeft ().accept (this , context );
213229 RelNode right = join .getRight ().accept (this , context );
230+ context .pushOuterRowType (left .getRowType (), right .getRowType ());
214231 RexNode condition =
215232 join .getCondition ()
216233 .map (c -> c .accept (expressionRexConverter , context ))
217234 .orElse (relBuilder .literal (true ));
218235 JoinRelType joinType = asJoinRelType (join );
219- RelNode node = relBuilder .push (left ).push (right ).join (joinType , condition ).build ();
236+ RelNode node =
237+ relBuilder
238+ .push (left )
239+ .push (right )
240+ .join (joinType , condition , context .popCorrelationIds ())
241+ .build ();
242+ context .popOuterRowType ();
220243 return applyRemap (node , join .getRemap ());
221244 }
222245
@@ -626,9 +649,101 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
626649 return relBuilder .push (relNode ).project (rexList ).build ();
627650 }
628651
652+ /** A shared context for the Substrait to RelNode conversion. */
629653 public static class Context implements VisitationContext {
654+ protected final Stack <RangeMap <Integer , RelDataType >> outerRowTypes = new Stack <>();
655+
656+ protected final Stack <java .util .Set <CorrelationId >> correlationIds = new Stack <>();
657+
658+ private int subqueryDepth ;
659+
660+ /**
661+ * Creates a new {@link Context} instance.
662+ *
663+ * @return the new {@link Context} instance
664+ */
630665 public static Context newContext () {
631666 return new Context ();
632667 }
668+
669+ /**
670+ * Adds the outer row types to the top of the stack of outer row types.
671+ *
672+ * <p>Row types are stored as a {@link RangeMap} with field indices as keys and the {@link
673+ * RelDataType} row type containing the field at the field index by continuously numbering the
674+ * field indices from 0 across all provided row types in the order the row types are passed as
675+ * arguments.
676+ *
677+ * @param inputs the row types to add
678+ */
679+ public void pushOuterRowType (final RelDataType ... inputs ) {
680+ final RangeMap <Integer , RelDataType > fieldRangeMap = TreeRangeMap .create ();
681+ int begin = 0 ;
682+ for (final RelDataType parent : inputs ) {
683+ final int end = begin + parent .getFieldCount ();
684+ final Range <Integer > range = Range .closedOpen (begin , end );
685+ fieldRangeMap .put (range , parent );
686+ begin = end ;
687+ }
688+
689+ outerRowTypes .push (fieldRangeMap );
690+ this .correlationIds .push (new HashSet <>());
691+ }
692+
693+ public void popOuterRowType () {
694+ outerRowTypes .pop ();
695+ }
696+
697+ /**
698+ * Returns the outer row type {@link RangeMap} walking up the given steps from the current
699+ * subquery depth.
700+ *
701+ * @param stepsOut number of steps to walk up from the current subquery depth
702+ * @return {@link RangeMap} with field indices as keys and the {@link RelDataType} row type
703+ * containing the field at the field index
704+ */
705+ public RangeMap <Integer , RelDataType > getOuterRowTypeRangeMap (final Integer stepsOut ) {
706+ return this .outerRowTypes .get (subqueryDepth - stepsOut );
707+ }
708+
709+ /**
710+ * Removes the correlation ids at the top of the stack.
711+ *
712+ * @return the correlation ids removed from the top of the stack
713+ */
714+ public java .util .Set <CorrelationId > popCorrelationIds () {
715+ return correlationIds .pop ();
716+ }
717+
718+ /**
719+ * Adds a {@link CorrelationId} to the subquery depth walking up the given steps from the
720+ * current subquery depth.
721+ *
722+ * @param stepsOut number of steps to walk up from the current subquery depth
723+ * @param correlationId the {@link CorrelationId} to add
724+ */
725+ public void addCorrelationId (final int stepsOut , final CorrelationId correlationId ) {
726+ final int index = subqueryDepth - stepsOut ;
727+ this .correlationIds .get (index ).add (correlationId );
728+ }
729+
730+ /** Increments the current subquery depth. */
731+ public void incrementSubqueryDepth () {
732+ this .subqueryDepth ++;
733+ }
734+
735+ /** Decrements the current subquery depth. */
736+ public void decrementSubqueryDepth () {
737+ this .subqueryDepth --;
738+ }
739+ }
740+
741+ /**
742+ * Returns the {@link RelBuilder} of this converter.
743+ *
744+ * @return the {@link RelBuilder}
745+ */
746+ public RelBuilder getRelBuilder () {
747+ return relBuilder ;
633748 }
634749}
0 commit comments