Skip to content

Commit 7cf1ccf

Browse files
authored
fix(isthmus): handle subqueries with outer field references (#426)
fixes #382 Signed-off-by: Niels Pardon <[email protected]>
1 parent 0d09f7a commit 7cf1ccf

File tree

7 files changed

+535
-15
lines changed

7 files changed

+535
-15
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.expression.Expression.FailureBehavior;
77
import io.substrait.expression.Expression.IfClause;
88
import io.substrait.expression.Expression.IfThen;
9+
import io.substrait.expression.Expression.PredicateOp;
910
import io.substrait.expression.Expression.SingleOrList;
1011
import io.substrait.expression.Expression.Switch;
1112
import io.substrait.expression.Expression.SwitchClause;
@@ -644,6 +645,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
644645
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
645646
}
646647

648+
public Expression.ScalarFunctionInvocation and(Expression... args) {
649+
// If any arg is nullable, the output of and is potentially nullable
650+
// For example: false and null = null
651+
boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
652+
Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
653+
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "and:bool", outputType, args);
654+
}
655+
647656
public Expression.ScalarFunctionInvocation or(Expression... args) {
648657
// If any arg is nullable, the output of or is potentially nullable
649658
// For example: false or null = null
@@ -706,4 +715,15 @@ public Plan plan(Plan.Root root) {
706715
public Rel.Remap remap(Integer... fields) {
707716
return Rel.Remap.of(Arrays.asList(fields));
708717
}
718+
719+
public Expression scalarSubquery(Rel input, Type type) {
720+
return Expression.ScalarSubquery.builder().input(input).type(type).build();
721+
}
722+
723+
public Expression exists(Rel rel) {
724+
return Expression.SetPredicate.builder()
725+
.tuples(rel)
726+
.predicateOp(PredicateOp.PREDICATE_OP_EXISTS)
727+
.build();
728+
}
709729
}

core/src/main/java/io/substrait/expression/FieldReference.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
3636
}
3737

3838
public boolean isSimpleRootReference() {
39-
return segments().size() == 1 && !inputExpression().isPresent();
39+
return segments().size() == 1
40+
&& !inputExpression().isPresent()
41+
&& !outerReferenceStepsOut().isPresent();
42+
}
43+
44+
public boolean isOuterReference() {
45+
return outerReferenceStepsOut().orElse(0) > 0;
4046
}
4147

4248
public FieldReference dereferenceStruct(int index) {

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
44

55
import com.google.common.collect.ImmutableList;
6+
import com.google.common.collect.Range;
7+
import com.google.common.collect.RangeMap;
8+
import com.google.common.collect.TreeRangeMap;
69
import io.substrait.expression.Expression;
710
import io.substrait.expression.Expression.SortDirection;
811
import io.substrait.expression.FunctionArg;
@@ -34,9 +37,11 @@
3437
import java.util.ArrayList;
3538
import java.util.Collection;
3639
import java.util.Collections;
40+
import java.util.HashSet;
3741
import java.util.List;
3842
import java.util.Optional;
3943
import java.util.OptionalLong;
44+
import java.util.Stack;
4045
import java.util.stream.Collectors;
4146
import java.util.stream.IntStream;
4247
import java.util.stream.Stream;
@@ -49,6 +54,7 @@
4954
import org.apache.calcite.rel.RelFieldCollation;
5055
import org.apache.calcite.rel.RelNode;
5156
import org.apache.calcite.rel.core.AggregateCall;
57+
import org.apache.calcite.rel.core.CorrelationId;
5258
import org.apache.calcite.rel.core.JoinRelType;
5359
import org.apache.calcite.rel.core.TableModify;
5460
import org.apache.calcite.rel.logical.LogicalTableModify;
@@ -156,8 +162,11 @@ public static RelNode convert(
156162
@Override
157163
public RelNode visit(Filter filter, Context context) throws RuntimeException {
158164
RelNode input = filter.getInput().accept(this, context);
165+
context.pushOuterRowType(input.getRowType());
159166
RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context);
160-
RelNode node = relBuilder.push(input).filter(filterCondition).build();
167+
RelNode node =
168+
relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build();
169+
context.popOuterRowType();
161170
return applyRemap(node, filter.getRemap());
162171
}
163172

@@ -183,6 +192,8 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti
183192
@Override
184193
public RelNode visit(Project project, Context context) throws RuntimeException {
185194
RelNode child = project.getInput().accept(this, context);
195+
context.pushOuterRowType(child.getRowType());
196+
186197
Stream<RexNode> directOutputs =
187198
IntStream.range(0, child.getRowType().getFieldCount())
188199
.mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex));
@@ -193,7 +204,12 @@ public RelNode visit(Project project, Context context) throws RuntimeException {
193204
List<RexNode> rexExprs =
194205
Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList());
195206

196-
RelNode node = relBuilder.push(child).project(rexExprs).build();
207+
RelNode node =
208+
relBuilder
209+
.push(child)
210+
.project(rexExprs, List.of(), false, context.popCorrelationIds())
211+
.build();
212+
context.popOuterRowType();
197213
return applyRemap(node, project.getRemap());
198214
}
199215

@@ -211,12 +227,19 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException {
211227
public RelNode visit(Join join, Context context) throws RuntimeException {
212228
RelNode left = join.getLeft().accept(this, context);
213229
RelNode right = join.getRight().accept(this, context);
230+
context.pushOuterRowType(left.getRowType(), right.getRowType());
214231
RexNode condition =
215232
join.getCondition()
216233
.map(c -> c.accept(expressionRexConverter, context))
217234
.orElse(relBuilder.literal(true));
218235
JoinRelType joinType = asJoinRelType(join);
219-
RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build();
236+
RelNode node =
237+
relBuilder
238+
.push(left)
239+
.push(right)
240+
.join(joinType, condition, context.popCorrelationIds())
241+
.build();
242+
context.popOuterRowType();
220243
return applyRemap(node, join.getRemap());
221244
}
222245

@@ -626,9 +649,101 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
626649
return relBuilder.push(relNode).project(rexList).build();
627650
}
628651

652+
/** A shared context for the Substrait to RelNode conversion. */
629653
public static class Context implements VisitationContext {
654+
protected final Stack<RangeMap<Integer, RelDataType>> outerRowTypes = new Stack<>();
655+
656+
protected final Stack<java.util.Set<CorrelationId>> correlationIds = new Stack<>();
657+
658+
private int subqueryDepth;
659+
660+
/**
661+
* Creates a new {@link Context} instance.
662+
*
663+
* @return the new {@link Context} instance
664+
*/
630665
public static Context newContext() {
631666
return new Context();
632667
}
668+
669+
/**
670+
* Adds the outer row types to the top of the stack of outer row types.
671+
*
672+
* <p>Row types are stored as a {@link RangeMap} with field indices as keys and the {@link
673+
* RelDataType} row type containing the field at the field index by continuously numbering the
674+
* field indices from 0 across all provided row types in the order the row types are passed as
675+
* arguments.
676+
*
677+
* @param inputs the row types to add
678+
*/
679+
public void pushOuterRowType(final RelDataType... inputs) {
680+
final RangeMap<Integer, RelDataType> fieldRangeMap = TreeRangeMap.create();
681+
int begin = 0;
682+
for (final RelDataType parent : inputs) {
683+
final int end = begin + parent.getFieldCount();
684+
final Range<Integer> range = Range.closedOpen(begin, end);
685+
fieldRangeMap.put(range, parent);
686+
begin = end;
687+
}
688+
689+
outerRowTypes.push(fieldRangeMap);
690+
this.correlationIds.push(new HashSet<>());
691+
}
692+
693+
public void popOuterRowType() {
694+
outerRowTypes.pop();
695+
}
696+
697+
/**
698+
* Returns the outer row type {@link RangeMap} walking up the given steps from the current
699+
* subquery depth.
700+
*
701+
* @param stepsOut number of steps to walk up from the current subquery depth
702+
* @return {@link RangeMap} with field indices as keys and the {@link RelDataType} row type
703+
* containing the field at the field index
704+
*/
705+
public RangeMap<Integer, RelDataType> getOuterRowTypeRangeMap(final Integer stepsOut) {
706+
return this.outerRowTypes.get(subqueryDepth - stepsOut);
707+
}
708+
709+
/**
710+
* Removes the correlation ids at the top of the stack.
711+
*
712+
* @return the correlation ids removed from the top of the stack
713+
*/
714+
public java.util.Set<CorrelationId> popCorrelationIds() {
715+
return correlationIds.pop();
716+
}
717+
718+
/**
719+
* Adds a {@link CorrelationId} to the subquery depth walking up the given steps from the
720+
* current subquery depth.
721+
*
722+
* @param stepsOut number of steps to walk up from the current subquery depth
723+
* @param correlationId the {@link CorrelationId} to add
724+
*/
725+
public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) {
726+
final int index = subqueryDepth - stepsOut;
727+
this.correlationIds.get(index).add(correlationId);
728+
}
729+
730+
/** Increments the current subquery depth. */
731+
public void incrementSubqueryDepth() {
732+
this.subqueryDepth++;
733+
}
734+
735+
/** Decrements the current subquery depth. */
736+
public void decrementSubqueryDepth() {
737+
this.subqueryDepth--;
738+
}
739+
}
740+
741+
/**
742+
* Returns the {@link RelBuilder} of this converter.
743+
*
744+
* @return the {@link RelBuilder}
745+
*/
746+
public RelBuilder getRelBuilder() {
747+
return relBuilder;
633748
}
634749
}

isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.substrait.isthmus.expression;
22

33
import com.google.common.collect.ImmutableList;
4+
import com.google.common.collect.Range;
5+
import com.google.common.collect.RangeMap;
46
import io.substrait.expression.AbstractExpressionVisitor;
57
import io.substrait.expression.EnumArg;
68
import io.substrait.expression.Expression;
@@ -33,6 +35,7 @@
3335
import java.util.stream.Stream;
3436
import org.apache.calcite.avatica.util.ByteString;
3537
import org.apache.calcite.rel.RelNode;
38+
import org.apache.calcite.rel.core.CorrelationId;
3639
import org.apache.calcite.rel.type.RelDataType;
3740
import org.apache.calcite.rel.type.RelDataTypeFactory;
3841
import org.apache.calcite.rex.RexBuilder;
@@ -513,7 +516,9 @@ private boolean isDistinct(Expression.WindowFunctionInvocation expr) {
513516
public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException {
514517
List<RexNode> needles =
515518
expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList());
519+
context.incrementSubqueryDepth();
516520
RelNode rel = expr.haystack().accept(relNodeConverter, context);
521+
context.decrementSubqueryDepth();
517522
return RexSubQuery.in(rel, ImmutableList.copyOf(needles));
518523
}
519524

@@ -589,13 +594,37 @@ public RexNode visit(Expression.Cast expr, Context context) throws RuntimeExcept
589594
@Override
590595
public RexNode visit(FieldReference expr, Context context) throws RuntimeException {
591596
if (expr.isSimpleRootReference()) {
592-
ReferenceSegment segment = expr.segments().get(0);
597+
final ReferenceSegment segment = expr.segments().get(0);
593598

594-
RexInputRef rexInputRef;
599+
final RexInputRef rexInputRef;
595600
if (segment instanceof FieldReference.StructField) {
596-
FieldReference.StructField f = (FieldReference.StructField) segment;
601+
final FieldReference.StructField field = (FieldReference.StructField) segment;
597602
rexInputRef =
598-
new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
603+
new RexInputRef(field.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
604+
} else {
605+
throw new IllegalArgumentException("Unhandled type: " + segment);
606+
}
607+
608+
return rexInputRef;
609+
} else if (expr.isOuterReference()) {
610+
final ReferenceSegment segment = expr.segments().get(0);
611+
612+
final RexNode rexInputRef;
613+
if (segment instanceof FieldReference.StructField) {
614+
final FieldReference.StructField field = (FieldReference.StructField) segment;
615+
616+
final RangeMap<Integer, RelDataType> fieldRangeMap =
617+
context.getOuterRowTypeRangeMap(expr.outerReferenceStepsOut().get());
618+
final Range<Integer> range = fieldRangeMap.getEntry(field.offset()).getKey();
619+
final int fieldOffset = field.offset() - range.lowerEndpoint();
620+
621+
final CorrelationId correlationId =
622+
relNodeConverter.getRelBuilder().getCluster().createCorrel();
623+
context.addCorrelationId(expr.outerReferenceStepsOut().get(), correlationId);
624+
rexInputRef =
625+
rexBuilder.makeFieldAccess(
626+
rexBuilder.makeCorrel(fieldRangeMap.get(field.offset()), correlationId),
627+
fieldOffset);
599628
} else {
600629
throw new IllegalArgumentException("Unhandled type: " + segment);
601630
}
@@ -646,13 +675,17 @@ public RexNode visitEnumArg(
646675

647676
@Override
648677
public RexNode visit(ScalarSubquery expr, Context context) throws RuntimeException {
678+
context.incrementSubqueryDepth();
649679
RelNode inputRelnode = expr.input().accept(relNodeConverter, context);
680+
context.decrementSubqueryDepth();
650681
return RexSubQuery.scalar(inputRelnode);
651682
}
652683

653684
@Override
654685
public RexNode visit(SetPredicate expr, Context context) throws RuntimeException {
686+
context.incrementSubqueryDepth();
655687
RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context);
688+
context.decrementSubqueryDepth();
656689
switch (expr.predicateOp()) {
657690
case PREDICATE_OP_EXISTS:
658691
return RexSubQuery.exists(inputRelnode);

isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.isthmus;
22

33
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
45

56
import io.substrait.proto.Plan;
67
import java.io.IOException;
@@ -13,7 +14,7 @@
1314
/** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. */
1415
public class TpcdsQueryTest extends PlanTestBase {
1516
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
16-
private static final Set<Integer> fromSubstraitExclusions = Set.of(6, 8, 67);
17+
private static final Set<Integer> fromSubstraitExclusions = Set.of(1, 30, 67, 81);
1718

1819
static IntStream testCases() {
1920
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));
@@ -32,6 +33,8 @@ public void testQuery(int query) throws IOException {
3233

3334
if (!fromSubstraitExclusions.contains(query)) {
3435
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
36+
} else {
37+
assertThrows(Throwable.class, () -> toSql(plan), "Substrait to SQL");
3538
}
3639
}
3740

isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44

55
import io.substrait.proto.Plan;
66
import java.io.IOException;
7-
import java.util.Set;
87
import java.util.stream.IntStream;
98
import org.apache.calcite.sql.parser.SqlParseException;
109
import org.junit.jupiter.params.ParameterizedTest;
1110
import org.junit.jupiter.params.provider.MethodSource;
1211

1312
/** TPC-H test to convert SQL to Substrait and then convert those plans back to SQL. */
1413
public class TpchQueryTest extends PlanTestBase {
15-
private static final Set<Integer> fromSubstraitExclusions = Set.of(17);
16-
1714
static IntStream testCases() {
1815
return IntStream.rangeClosed(1, 22);
1916
}
@@ -29,9 +26,7 @@ public void testQuery(int query) throws IOException {
2926

3027
Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait");
3128

32-
if (!fromSubstraitExclusions.contains(query)) {
33-
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
34-
}
29+
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
3530
}
3631

3732
private Plan toSubstraitPlan(String sql) throws SqlParseException {

0 commit comments

Comments
 (0)