Skip to content
335 changes: 329 additions & 6 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.objects.ObjectAVLTreeSet;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.HumanReadableBytes;
Expand Down Expand Up @@ -3233,6 +3234,10 @@
* Primarily internal helper function used to coerce null, [], and [null] into [null], similar to the logic done
* by {@link org.apache.druid.segment.virtual.ExpressionSelectors#supplierFromDimensionSelector} when the 3rd
* argument is true, which is done when implicitly mapping scalar functions over mvd values.
*
* Was formerly generated by the SQL layer for MV_CONTAINS and MV_OVERLAP, but is no longer generated, since the
* SQL layer now prefers using {@link MvContainsFunction} and {@link MvOverlapFunction}. This function remains here
* for backwards compatibility.
*/
class MultiValueStringHarmonizeNullsFunction implements Function
{
Expand All @@ -3245,11 +3250,7 @@
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
  // Delegate to the shared helper rather than repeating its logic inline: it casts the evaluated argument to
  // STRING_ARRAY and replaces a null value or an empty array with [null]. Keeping a second copy of that logic
  // followed by another return statement would leave an unreachable statement in this method.
  return harmonizeMultiValue(args.get(0).eval(bindings));
}

@Override
Expand Down Expand Up @@ -3698,7 +3699,7 @@
}

if (scalarEval.value() == null) {
return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
return arrayContainsNull(array) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
}

final ExpressionType matchType = arrayEval.elementType();
Expand Down Expand Up @@ -4157,6 +4158,300 @@
}
}

/**
 * Expression function {@code mv_overlap(arg1, arg2)}. Evaluates to a LONG boolean telling whether the two inputs
 * have at least one element in common.
 *
 * The first argument is passed through {@code harmonizeMultiValue} (so null and [] are treated as [null]) and is
 * then cast to the array type of the second argument before any comparison. Results:
 * <ul>
 * <li>second argument null: true if the first argument contains null, otherwise false</li>
 * <li>second argument empty: false</li>
 * <li>some element of the first argument is present in the second: true</li>
 * <li>no common element, the first argument contains null and the second does not: null</li>
 * <li>otherwise: false</li>
 * </ul>
 */
class MvOverlapFunction implements Function
{
  @Override
  public String name()
  {
    return "mv_overlap";
  }

  @Nullable
  @Override
  public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
  {
    return ExpressionType.LONG;
  }

  @Override
  public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
  {
    final ExprEval<?> arg1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
    final ExprEval<?> arg2 = args.get(1).eval(bindings);

    // Cast arg1 to arg2's type.
    final Object[] array2 = arg2.asArray();
    final ExpressionType array2Type = arg2.asArrayType();
    final Object[] array1 = arg1.castTo(array2Type).asArray();

    // If the second argument is null, check if the first argument contains null.
    if (array2 == null) {
      return ExprEval.ofLongBoolean(arrayContainsNull(array1));
    }

    // If the second argument is empty array, return false regardless of first argument.
    if (array2.length == 0) {
      return ExprEval.ofLongBoolean(false);
    }

    // Check for overlap.
    final Set<Object> set2 = new ObjectOpenHashSet<>(array2);
    for (final Object check : array1) {
      if (set2.contains(check)) {
        return ExprEval.ofLongBoolean(true);
      }
    }

    // No overlap. A null on the left that has no counterpart on the right yields a null result instead of false.
    if (!set2.contains(null) && arrayContainsNull(array1)) {
      return ExprEval.ofLong(null);
    } else {
      return ExprEval.ofLongBoolean(false);
    }
  }

  @Override
  public void validateArguments(List<Expr> args)
  {
    validationHelperCheckArgumentCount(args, 2);
  }

  @Override
  public Set<Expr> getScalarInputs(List<Expr> args)
  {
    return Collections.emptySet();
  }

  @Override
  public Set<Expr> getArrayInputs(List<Expr> args)
  {
    return ImmutableSet.copyOf(args);
  }

  @Override
  public boolean hasArrayInputs()
  {
    return true;
  }

  /**
   * When the second argument is a literal, evaluates it once up front and returns a specialization that avoids
   * re-evaluating it in every {@code apply} call. Literals that are neither null, empty, nor arrays of a
   * primitive element type get no specialization, and this instance is returned unchanged.
   */
  @Override
  public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
  {
    final Expr arg2 = args.get(1);

    if (arg2.isLiteral()) {
      final ExprEval<?> rhsEval = arg2.eval(InputBindings.nilBindings());
      final Object[] rhsArray = rhsEval.asArray();

      if (rhsArray == null) {
        return new MvOverlapConstantNull(rhsEval.asArrayType());
      } else if (rhsArray.length == 0) {
        return new MvOverlapConstantEmpty();
      } else if (rhsEval.elementType().isPrimitive()) {
        return new MvOverlapConstantArray(
            new ObjectOpenHashSet<>(rhsArray),
            arrayContainsNull(rhsArray),
            rhsEval.asArrayType()
        );
      }
    }

    return this;
  }

  /**
   * Specialization for a null literal as the second argument: true if the first argument contains null.
   */
  private static final class MvOverlapConstantNull extends MvOverlapFunction
  {
    private final ExpressionType matchArrayType;

    MvOverlapConstantNull(final ExpressionType matchArrayType)
    {
      this.matchArrayType = matchArrayType;
    }

    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      final ExprEval<?> arrayExpr1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
      // Cast before the null check, exactly as the general apply() does, so this specialization cannot return a
      // different answer than the unspecialized function for the same inputs.
      final Object[] array1 = arrayExpr1.castTo(matchArrayType).asArray();
      return ExprEval.ofLongBoolean(arrayContainsNull(array1));
    }
  }

  /**
   * Specialization for an empty-array literal as the second argument: always false.
   */
  private static final class MvOverlapConstantEmpty extends MvOverlapFunction
  {
    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      return ExprEval.ofLongBoolean(false);
    }
  }

  /**
   * Specialization for a non-empty literal array as the second argument. The lookup set, whether the literal
   * contains null, and the literal's array type are all computed once by the caller.
   */
  private static final class MvOverlapConstantArray extends MvOverlapFunction
  {
    private final Set<Object> matchValues;
    private final boolean rhsHasNull;
    private final ExpressionType matchArrayType;

    public MvOverlapConstantArray(Set<Object> matchValues, boolean rhsHasNull, ExpressionType matchArrayType)
    {
      this.matchValues = matchValues;
      this.rhsHasNull = rhsHasNull;
      this.matchArrayType = matchArrayType;
    }

    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      final ExprEval<?> arrayExpr1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
      final Object[] array1 = arrayExpr1.castTo(matchArrayType).asArray();

      for (final Object check : array1) {
        if (matchValues.contains(check)) {
          return ExprEval.ofLongBoolean(true);
        }
      }

      // No overlap. Same null rule as the general apply().
      if (!rhsHasNull && arrayContainsNull(array1)) {
        return ExprEval.ofLong(null);
      } else {
        return ExprEval.ofLongBoolean(false);
      }
    }
  }
}

/**
 * Expression function {@code mv_contains(arg1, arg2)}. Evaluates to a LONG boolean telling whether the first
 * input contains every element of the second input.
 *
 * The first argument is passed through {@code harmonizeMultiValue} (so null and [] are treated as [null]) and is
 * then cast to the array type of the second argument before any comparison. Results:
 * <ul>
 * <li>second argument null: true if the first argument contains null, otherwise false</li>
 * <li>second argument empty: true</li>
 * <li>otherwise: true if every element of the second argument is found in the first</li>
 * </ul>
 */
class MvContainsFunction implements Function
{
  @Override
  public String name()
  {
    return "mv_contains";
  }

  @Nullable
  @Override
  public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
  {
    return ExpressionType.LONG;
  }

  @Override
  public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
  {
    final ExprEval<?> arg1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
    final ExprEval<?> arg2 = args.get(1).eval(bindings);

    // Cast arg1 to arg2's type.
    final Object[] array2 = arg2.asArray();
    final ExpressionType array2Type = arg2.asArrayType();
    final Object[] array1 = arg1.castTo(array2Type).asArray();

    // If the second argument is null, check if the first argument contains null.
    if (array2 == null) {
      return ExprEval.ofLongBoolean(arrayContainsNull(array1));
    }

    // If the second argument is an empty array, return true regardless of the first argument.
    if (array2.length == 0) {
      return ExprEval.ofLongBoolean(true);
    }

    return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2)));
  }

  @Override
  public void validateArguments(List<Expr> args)
  {
    validationHelperCheckArgumentCount(args, 2);
  }

  @Override
  public Set<Expr> getScalarInputs(List<Expr> args)
  {
    return Collections.emptySet();
  }

  @Override
  public Set<Expr> getArrayInputs(List<Expr> args)
  {
    return ImmutableSet.copyOf(args);
  }

  @Override
  public boolean hasArrayInputs()
  {
    return true;
  }

  /**
   * When the second argument is a literal, evaluates it once up front and returns a specialization that avoids
   * re-evaluating it in every {@code apply} call. A null literal and a single-element literal both map to the
   * single-value specialization. Multi-element literals are only specialized when their element type is
   * primitive; otherwise this instance is returned unchanged.
   */
  @Override
  public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
  {
    final Expr arg2 = args.get(1);

    if (arg2.isLiteral()) {
      final ExprEval<?> rhsEval = arg2.eval(InputBindings.nilBindings());
      final Object[] rhsArray = rhsEval.asArray();

      if (rhsArray == null) {
        return new MvContainsConstantScalar(null, rhsEval.asArrayType());
      } else if (rhsArray.length == 0) {
        return new MvContainsConstantEmpty();
      } else if (rhsArray.length == 1) {
        return new MvContainsConstantScalar(rhsArray[0], rhsEval.asArrayType());
      } else if (rhsEval.elementType().isPrimitive()) {
        return new MvContainsConstantArray(rhsArray, rhsEval.asArrayType());
      }
    }

    return this;
  }

  /**
   * Specialization for a literal array with more than one element as the second argument.
   */
  private static final class MvContainsConstantArray extends MvContainsFunction
  {
    private final List<Object> matchValues;
    private final ExpressionType matchArrayType;

    public MvContainsConstantArray(final Object[] matchValues, final ExpressionType matchArrayType)
    {
      this.matchValues = Arrays.asList(matchValues);
      this.matchArrayType = matchArrayType;
    }

    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      final ExprEval<?> arrayExpr1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
      final Object[] array1 = arrayExpr1.castTo(matchArrayType).asArray();
      return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(matchValues));
    }
  }

  /**
   * Specialization for a single value (possibly null) as the second argument: true if the first argument, after
   * being cast to the literal's array type, contains that value.
   */
  private static final class MvContainsConstantScalar extends MvContainsFunction
  {
    @Nullable
    private final Object matchValue;
    private final ExpressionType matchArrayType;

    public MvContainsConstantScalar(@Nullable final Object matchValue, final ExpressionType matchArrayType)
    {
      this.matchValue = matchValue;
      this.matchArrayType = matchArrayType;
    }

    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      final ExprEval<?> arrayExpr1 = Function.harmonizeMultiValue(args.get(0).eval(bindings));
      final Object[] array1 = arrayExpr1.castTo(matchArrayType).asArray();
      return ExprEval.ofLongBoolean(Arrays.asList(array1).contains(matchValue));
    }
  }

  /**
   * Specialization for an empty-array literal as the second argument: always true.
   */
  private static final class MvContainsConstantEmpty extends MvContainsFunction
  {
    @Override
    public ExprEval<?> apply(List<Expr> args, Expr.ObjectBinding bindings)
    {
      return ExprEval.ofLongBoolean(true);
    }
  }
}

class ArraySliceFunction implements Function
{
@Override
Expand Down Expand Up @@ -4336,4 +4631,32 @@
return HumanReadableBytes.UnitSystem.DECIMAL;
}
}

/**
 * Normalizes a value so it can be used like a multi-value dimension. The input is cast to
 * {@link ExpressionType#STRING_ARRAY}, which is also the type of every returned value. A null value or a
 * zero-length array is replaced by a one-element array holding null, so null, [], and [null] all come out as
 * [null]. This is similar to the logic done by
 * {@link org.apache.druid.segment.virtual.ExpressionSelectors#supplierFromDimensionSelector} when "homogenize"
 * is true.
 */
private static ExprEval<?> harmonizeMultiValue(ExprEval<?> eval)
{
  final ExprEval<?> asStrings = eval.castTo(ExpressionType.STRING_ARRAY);
  final boolean hasElements = asStrings.value() != null && asStrings.asArray().length > 0;
  if (hasElements) {
    return asStrings;
  }
  // Nothing to keep: null and [] are both represented as [null].
  return ExprEval.ofArray(ExpressionType.STRING_ARRAY, new Object[]{null});
}

/**
 * Scans an array for a null element.
 *
 * @param array the array to scan; it is dereferenced without a null check
 * @return true if at least one element is null, false otherwise (including for an empty array)
 */
private static boolean arrayContainsNull(Object[] array)
{
  // Advance past non-null elements; stopping before the end means a null was found.
  int idx = 0;
  while (idx < array.length && array[idx] != null) {
    idx++;
  }
  return idx < array.length;
}
}
Loading
Loading