Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions velox/functions/lib/TransformFunctionBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/vector/FunctionVector.h"

namespace facebook::velox::functions {

/// Base class for array transform functions.
/// Subclasses can override addIndexVector() to provide additional lambda
/// arguments (e.g., element index for Spark's transform function).
class TransformFunctionBase : public exec::VectorFunction {
public:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 2);

// Flatten input array.
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
auto& decodedArray = *arrayDecoder.get();

auto flatArray = flattenArray(rows, args[0], decodedArray);
auto newNumElements = flatArray->elements()->size();

std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};

// Allow subclasses to add additional lambda arguments (e.g., index vector).
addIndexVector(args, flatArray, newNumElements, context, lambdaArgs);

SelectivityVector validRowsInReusedResult =
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());

// Transformed elements.
VectorPtr newElements;

auto elementToTopLevelRows = getElementToTopLevelRows(
newNumElements, rows, flatArray.get(), context.pool());

// Loop over lambda functions and apply these to elements of the base array.
// In most cases there will be only one function and the loop will run once.
auto it = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
while (auto entry = it.next()) {
auto elementRows = toElementRows<ArrayVector>(
newNumElements, *entry.rows, flatArray.get());
auto wrapCapture = toWrapCapture<ArrayVector>(
newNumElements, entry.callable, *entry.rows, flatArray);

entry.callable->apply(
elementRows,
&validRowsInReusedResult,
wrapCapture,
&context,
lambdaArgs,
elementToTopLevelRows,
&newElements);
}

// Set nulls for rows not present in 'rows'.
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);

VectorPtr localResult = std::make_shared<ArrayVector>(
flatArray->pool(),
outputType,
std::move(newNulls),
rows.end(),
flatArray->offsets(),
flatArray->sizes(),
newElements);
context.moveOrCopyResult(localResult, rows, result);
}

/// Returns the base signature: array(T), function(T, U) -> array(U).
/// Subclasses can call this and add additional signatures.
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.returnType("array(U)")
.argumentType("array(T)")
.argumentType("function(T, U)")
.build()};
}

protected:
/// Override this method to add additional arguments to the lambda.
/// Default implementation does nothing (element-only transform).
/// @param args The original function arguments.
/// @param flatArray The flattened input array.
/// @param numElements Total number of elements across all arrays.
/// @param context The evaluation context.
/// @param lambdaArgs Output vector to append additional arguments to.
virtual void addIndexVector(
const std::vector<VectorPtr>& /*args*/,
const ArrayVectorPtr& /*flatArray*/,
vector_size_t /*numElements*/,
exec::EvalCtx& /*context*/,
std::vector<VectorPtr>& /*lambdaArgs*/) const {}
};

} // namespace facebook::velox::functions
81 changes: 5 additions & 76 deletions velox/functions/prestosql/Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,88 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/vector/FunctionVector.h"
#include "velox/functions/lib/TransformFunctionBase.h"

namespace facebook::velox::functions {
namespace {

// See documentation at https://prestodb.io/docs/current/functions/array.html
class TransformFunction : public exec::VectorFunction {
public:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 2);

// Flatten input array.
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
auto& decodedArray = *arrayDecoder.get();

auto flatArray = flattenArray(rows, args[0], decodedArray);

std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};
auto newNumElements = flatArray->elements()->size();

SelectivityVector validRowsInReusedResult =
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());

// transformed elements
VectorPtr newElements;

auto elementToTopLevelRows = getElementToTopLevelRows(
newNumElements, rows, flatArray.get(), context.pool());

// loop over lambda functions and apply these to elements of the base array;
// in most cases there will be only one function and the loop will run once
auto it = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
while (auto entry = it.next()) {
auto elementRows = toElementRows<ArrayVector>(
newNumElements, *entry.rows, flatArray.get());
auto wrapCapture = toWrapCapture<ArrayVector>(
newNumElements, entry.callable, *entry.rows, flatArray);

entry.callable->apply(
elementRows,
&validRowsInReusedResult,
wrapCapture,
&context,
lambdaArgs,
elementToTopLevelRows,
&newElements);
}

// Set nulls for rows not present in 'rows'.
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);

VectorPtr localResult = std::make_shared<ArrayVector>(
flatArray->pool(),
outputType,
std::move(newNulls),
rows.end(),
flatArray->offsets(),
flatArray->sizes(),
newElements);
context.moveOrCopyResult(localResult, rows, result);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), function(T, U) -> array(U)
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.returnType("array(U)")
.argumentType("array(T)")
.argumentType("function(T, U)")
.build()};
}
// Presto flavor of transform: the lambda receives only the array element.
class TransformFunction : public TransformFunctionBase {
// apply() and signatures() are inherited unchanged from TransformFunctionBase.
// addIndexVector() is intentionally not overridden, so no lambda arguments
// are added beyond the element itself.
};

} // namespace

/// transform is null preserving for the array. But since an
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ velox_add_library(
velox_functions_spark_impl
ArrayGetFunction.cpp
ArraySort.cpp
Transform.cpp
CharVarcharUtils.cpp
Comparisons.cpp
ConcatWs.cpp
Expand Down
109 changes: 109 additions & 0 deletions velox/functions/sparksql/Transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/BitUtil.h"
#include "velox/functions/lib/TransformFunctionBase.h"

namespace facebook::velox::functions::sparksql {
namespace {

/// Spark's transform function supports both signatures:
/// 1. transform(array, x -> expr) - element only
/// 2. transform(array, (x, i) -> expr) - element + index (Spark-specific)
///
/// See Spark documentation:
/// https://spark.apache.org/docs/latest/api/sql/index.html#transform
class TransformFunction : public TransformFunctionBase {
public:
/// Returns both base signature and Spark-specific signature with index.
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
auto sigs = TransformFunctionBase::signatures();
// Add Spark-specific signature: array(T), function(T, integer, U) ->
// array(U). Spark uses IntegerType (32-bit) for the index parameter.
sigs.push_back(
exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.returnType("array(U)")
.argumentType("array(T)")
.argumentType("function(T, integer, U)")
.build());
return sigs;
}

protected:
/// Adds index vector to lambda arguments if the lambda expects it.
void addIndexVector(
const std::vector<VectorPtr>& args,
const ArrayVectorPtr& flatArray,
vector_size_t numElements,
exec::EvalCtx& context,
std::vector<VectorPtr>& lambdaArgs) const override {
// Check the lambda function type to see if it expects 2 input arguments.
// function(T, U) has 2 children (input T, output U) -> 1 input arg.
// function(T, integer, U) has 3 children (input T, index integer, output U)
// -> 2 input args.
auto functionType = args[1]->type();
bool withIndex = functionType->size() == 3;

if (withIndex) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the only change with presto TransformFunction is here, am I right? So I would prefer a base TransformFunction and add a virtual addIndexVectors(lambdaArgs) with default empty implementation,

lambdaArgs.push_back(createIndexVector(flatArray, numElements, context));
}
}

private:
/// Creates an index vector where each element contains its position within
/// its respective array. For example, if we have arrays [a, b] and [c, d, e],
/// the index vector will be [0, 1, 0, 1, 2].
/// Spark uses IntegerType (32-bit) for the index.
static VectorPtr createIndexVector(
const ArrayVectorPtr& flatArray,
vector_size_t numElements,
exec::EvalCtx& context) {
auto* pool = context.pool();
auto indexVector =
BaseVector::create<FlatVector<int32_t>>(INTEGER(), numElements, pool);

auto* rawOffsets = flatArray->rawOffsets();
auto* rawSizes = flatArray->rawSizes();
auto* rawNulls = flatArray->rawNulls();
auto* rawIndices = indexVector->mutableRawValues();

for (vector_size_t row = 0; row < flatArray->size(); ++row) {
// Skip null arrays.
if (rawNulls && bits::isBitNull(rawNulls, row)) {
continue;
}
auto offset = rawOffsets[row];
auto size = rawSizes[row];
for (vector_size_t i = 0; i < size; ++i) {
rawIndices[offset + i] = i;
}
}

return indexVector;
}
};

} // namespace

// Declares the vector function tagged 'udf_spark_transform'. It advertises
// the signatures returned by TransformFunction::signatures() and is backed by
// a TransformFunction instance. The metadata turns default null behavior off.
// NOTE(review): presumably so the function itself handles null array rows
// rather than the framework - confirm against VectorFunctionMetadata.
VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA(
udf_spark_transform,
TransformFunction::signatures(),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
std::make_unique<TransformFunction>());

} // namespace facebook::velox::functions::sparksql
9 changes: 8 additions & 1 deletion velox/functions/sparksql/registration/RegisterArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ namespace facebook::velox::functions {
// vector function definition.
// Higher order functions.
void registerSparkArrayFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform");
VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "aggregate");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, prefix + "array");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, prefix + "array_contains");
Expand All @@ -52,6 +51,12 @@ void registerSparkArrayFunctions(const std::string& prefix) {
}

namespace sparksql {

// Registers the Spark-specific transform (the variant with index support)
// under the name formed by appending "transform" to 'prefix'.
void registerSparkTransform(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_spark_transform, prefix + "transform");
}

template <typename T>
inline void registerArrayConcatFunction(const std::string& prefix) {
registerFunction<
Expand Down Expand Up @@ -216,6 +221,8 @@ void registerArrayFunctions(const std::string& prefix) {
registerArrayRemoveFunctions(prefix);
registerArrayPrependFunctions(prefix);
registerSparkArrayFunctions(prefix);
// Register Spark-specific transform with index support.
registerSparkTransform(prefix);
// Register array sort functions.
exec::registerStatefulVectorFunction(
prefix + "array_sort", arraySortSignatures(true), makeArraySortAsc);
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ add_executable(
StringTest.cpp
StringToMapTest.cpp
ToJsonTest.cpp
TransformTest.cpp
UnBase64Test.cpp
UnscaledValueFunctionTest.cpp
UpperLowerTest.cpp
Expand Down
Loading
Loading