Skip to content

Commit 77e7f8f

Browse files
authored
fix(isthmus): tpcds q67 (#503)
1 parent 16f4b27 commit 77e7f8f

File tree

3 files changed

+65
-3
lines changed

3 files changed

+65
-3
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,19 @@ protected Expand newExpand(ExpandRel rel) {
620620

621621
protected Aggregate newAggregate(AggregateRel rel) {
622622
Rel input = from(rel.getInput());
623+
Type.Struct inputSchema;
624+
if (input instanceof Project) {
625+
List<Type> types =
626+
((Project) input)
627+
.getExpressions().stream().map(Expression::getType).collect(Collectors.toList());
628+
inputSchema = Type.Struct.builder().fields(types).nullable(false).build();
629+
} else {
630+
inputSchema = input.getRecordType();
631+
}
632+
623633
ProtoExpressionConverter protoExprConverter =
624-
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
634+
new ProtoExpressionConverter(lookup, extensions, inputSchema, this);
635+
625636
ProtoAggregateFunctionConverter protoAggrFuncConverter =
626637
new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter);
627638

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,60 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
244244
return Set.builder().inputs(inputs).setOp(setOp).build();
245245
}
246246

247+
/**
248+
* Pre-processes the input to an Aggregate relation to handle nullability changes introduced by
249+
* ROLLUP/CUBE/GROUPING SETS.
250+
*
251+
* @param aggregate The original Calcite aggregate node.
252+
* @return A Substrait Rel node that is correctly typed to be the input to the Substrait
253+
* Aggregate.
254+
*/
255+
private Rel handleRollupCorrection(org.apache.calcite.rel.core.Aggregate aggregate) {
256+
Rel originalInput = apply(aggregate.getInput());
257+
258+
// Determine the correct final output type for the aggregate, which accounts for nullability.
259+
NamedStruct aggregateOutputType = typeConverter.toNamedStruct(aggregate.getRowType());
260+
List<Integer> groupKeyIndices = aggregate.getGroupSet().asList();
261+
262+
// Create a list of expressions to cast the original input to the correct final type if needed.
263+
List<Expression> castExpressions = new ArrayList<>();
264+
265+
boolean needsCasting = false;
266+
for (int i = 0; i < originalInput.getRecordType().fields().size(); i++) {
267+
Expression fieldReference = FieldReference.newInputRelReference(i, originalInput);
268+
269+
if (groupKeyIndices.contains(i)) {
270+
int groupKeyOutputIndex = groupKeyIndices.indexOf(i);
271+
Type finalType = aggregateOutputType.struct().fields().get(groupKeyOutputIndex);
272+
273+
if (finalType.nullable() && !fieldReference.getType().nullable()) {
274+
needsCasting = true; // Mark that a cast is necessary.
275+
castExpressions.add(
276+
Expression.Cast.builder()
277+
.type(finalType)
278+
.input(fieldReference)
279+
.failureBehavior(Expression.FailureBehavior.RETURN_NULL)
280+
.build());
281+
} else {
282+
castExpressions.add(fieldReference);
283+
}
284+
} else {
285+
castExpressions.add(fieldReference);
286+
}
287+
}
288+
289+
// Only add the extra Project node if a cast was actually needed.
290+
if (needsCasting) {
291+
return Project.builder().input(originalInput).expressions(castExpressions).build();
292+
}
293+
294+
// If no casting was needed, just return the original converted input.
295+
return originalInput;
296+
}
297+
247298
@Override
248299
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
249-
Rel input = apply(aggregate.getInput());
300+
Rel input = handleRollupCorrection(aggregate);
250301
Stream<ImmutableBitSet> sets;
251302
if (aggregate.groupSets != null) {
252303
sets = aggregate.groupSets.stream();

isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
public class TpcdsQueryTest extends PlanTestBase {
1616
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
1717
private static final Set<Integer> fromSubstraitPojoExclusions = Set.of(1, 30, 81);
18-
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81);
18+
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 81);
1919

2020
static IntStream testCases() {
2121
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));

0 commit comments

Comments
 (0)