Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ velox_add_library(
velox_functions_spark_impl
ArrayGetFunction.cpp
ArraySort.cpp
Transform.cpp
CharVarcharUtils.cpp
Comparisons.cpp
ConcatWs.cpp
Expand Down
171 changes: 171 additions & 0 deletions velox/functions/sparksql/Transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/BitUtil.h"

Check warning on line 17 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

included header BitUtil.h is not used directly
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/vector/FunctionVector.h"

namespace facebook::velox::functions::sparksql {
namespace {

/// Spark's transform function supports two signatures:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

supporting -> supports

/// 1. transform(array, x -> expr) - element only
/// 2. transform(array, (x, i) -> expr) - element + index (Spark-specific)
///
/// See Spark documentation:
/// https://spark.apache.org/docs/latest/api/sql/index.html#transform
class SparkTransformFunction : public exec::VectorFunction {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove Spark prefix since it has been in sparksql namespace

public:
void apply(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse some code of udf_transform, move some of them to common path velox/functions/lib?

const SelectivityVector& rows,

Check warning on line 36 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "facebook::velox::SelectivityVector" is directly included
std::vector<VectorPtr>& args,

Check warning on line 37 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "facebook::velox::VectorPtr" is directly included

Check warning on line 37 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "std::vector" is directly included
const TypePtr& outputType,

Check warning on line 38 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "facebook::velox::TypePtr" is directly included
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 2);

Check warning on line 41 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "VELOX_CHECK_EQ" is directly included

// Flatten input array.
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);

Check warning on line 44 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-include-cleaner

no header providing "facebook::velox::exec::LocalDecodedVector" is directly included
auto& decodedArray = *arrayDecoder.get();

Check warning on line 45 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

readability-redundant-smartptr-get

redundant get() call on smart pointer

auto flatArray = flattenArray(rows, args[0], decodedArray);
auto newNumElements = flatArray->elements()->size();

// Determine if we need to pass index to the lambda.
// Check the lambda function type to see if it expects 2 input arguments.
// function(T, U) has 2 children (input T, output U) -> 1 input arg.
// function(T, integer, U) has 3 children (input T, index integer, output U)
// -> 2 input args.
auto functionType = args[1]->type();
bool withIndex = functionType->size() == 3;

Check warning on line 56 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-const-correctness

variable 'withIndex' of type 'bool' can be declared 'const'

std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};

// If lambda expects index, create index vector.
if (withIndex) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the only change with presto TransformFunction is here, am I right? So I would prefer a base TransformFunction and add a virtual addIndexVectors(lambdaArgs) with default empty implementation,

auto indexVector = createIndexVector(flatArray, newNumElements, context);
lambdaArgs.push_back(indexVector);
}

SelectivityVector validRowsInReusedResult =

Check warning on line 66 in velox/functions/sparksql/Transform.cpp

View workflow job for this annotation

GitHub Actions / Build with GCC / Linux release with adapters

misc-const-correctness

variable 'validRowsInReusedResult' of type 'SelectivityVector' can be declared 'const'
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());

// Transformed elements.
VectorPtr newElements;

auto elementToTopLevelRows = getElementToTopLevelRows(
newNumElements, rows, flatArray.get(), context.pool());

// Loop over lambda functions and apply these to elements of the base array.
// In most cases there will be only one function and the loop will run once.
auto lambdaIt = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
while (auto entry = lambdaIt.next()) {
auto elementRows = toElementRows<ArrayVector>(
newNumElements, *entry.rows, flatArray.get());
auto wrapCapture = toWrapCapture<ArrayVector>(
newNumElements, entry.callable, *entry.rows, flatArray);

entry.callable->apply(
elementRows,
&validRowsInReusedResult,
wrapCapture,
&context,
lambdaArgs,
elementToTopLevelRows,
&newElements);
}

// Set nulls for rows not present in 'rows'.
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);

VectorPtr localResult = std::make_shared<ArrayVector>(
flatArray->pool(),
outputType,
std::move(newNulls),
rows.end(),
flatArray->offsets(),
flatArray->sizes(),
newElements);
context.moveOrCopyResult(localResult, rows, result);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the signatures, call the super function signatures() and add the spark specified function signature

return {
// Signature 1: array(T), function(T, U) -> array(U) (element only)
exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.returnType("array(U)")
.argumentType("array(T)")
.argumentType("function(T, U)")
.build(),
// Signature 2: array(T), function(T, integer, U) -> array(U) (element +
// index). Spark uses IntegerType (32-bit) for the index parameter.
exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.returnType("array(U)")
.argumentType("array(T)")
.argumentType("function(T, integer, U)")
.build()};
}

private:
/// Creates an index vector where each element contains its position within
/// its respective array. For example, if we have arrays [a, b] and [c, d, e],
/// the index vector will be [0, 1, 0, 1, 2].
/// Spark uses IntegerType (32-bit) for the index.
static VectorPtr createIndexVector(
const std::shared_ptr<ArrayVector>& flatArray,
vector_size_t numElements,
exec::EvalCtx& context) {
auto* pool = context.pool();
auto indexVector =
BaseVector::create<FlatVector<int32_t>>(INTEGER(), numElements, pool);

auto* rawOffsets = flatArray->rawOffsets();
auto* rawSizes = flatArray->rawSizes();
auto* rawNulls = flatArray->rawNulls();
auto* rawIndices = indexVector->mutableRawValues();

for (vector_size_t row = 0; row < flatArray->size(); ++row) {
// Skip null arrays.
if (rawNulls && bits::isBitNull(rawNulls, row)) {
continue;
}
auto offset = rawOffsets[row];
auto size = rawSizes[row];
for (vector_size_t i = 0; i < size; ++i) {
rawIndices[offset + i] = i;
}
}

return indexVector;
}
};

} // namespace

// Declares the vector function under the key 'udf_spark_transform'.
// defaultNullBehavior(false) is required: the function must see rows with
// null input arrays so it can propagate nulls itself (see
// addNullsForUnselectedRows in apply) rather than being skipped by the
// engine's default null-propagation path.
VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA(
udf_spark_transform,
SparkTransformFunction::signatures(),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
std::make_unique<SparkTransformFunction>());

} // namespace facebook::velox::functions::sparksql
9 changes: 8 additions & 1 deletion velox/functions/sparksql/registration/RegisterArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ namespace facebook::velox::functions {
// vector function definition.
// Higher order functions.
void registerSparkArrayFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform");
VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "aggregate");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, prefix + "array");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, prefix + "array_contains");
Expand All @@ -52,6 +51,12 @@ void registerSparkArrayFunctions(const std::string& prefix) {
}

namespace sparksql {

// Register the Spark-specific transform with index support.
// Binds the 'udf_spark_transform' vector function (declared in
// velox/functions/sparksql/Transform.cpp) to '<prefix>transform'. This
// replaces the previous registration of 'udf_transform', which the diff above
// removes from registerSparkArrayFunctions.
void registerSparkTransform(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_spark_transform, prefix + "transform");
}

template <typename T>
inline void registerArrayConcatFunction(const std::string& prefix) {
registerFunction<
Expand Down Expand Up @@ -216,6 +221,8 @@ void registerArrayFunctions(const std::string& prefix) {
registerArrayRemoveFunctions(prefix);
registerArrayPrependFunctions(prefix);
registerSparkArrayFunctions(prefix);
// Register Spark-specific transform with index support.
registerSparkTransform(prefix);
// Register array sort functions.
exec::registerStatefulVectorFunction(
prefix + "array_sort", arraySortSignatures(true), makeArraySortAsc);
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ add_executable(
StringTest.cpp
StringToMapTest.cpp
ToJsonTest.cpp
TransformTest.cpp
UnBase64Test.cpp
UnscaledValueFunctionTest.cpp
UpperLowerTest.cpp
Expand Down
159 changes: 159 additions & 0 deletions velox/functions/sparksql/tests/TransformTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

using namespace facebook::velox::test;

class TransformTest : public SparkFunctionBaseTest {
 protected:
  // Evaluates 'expression' over a single-column row vector holding 'input'
  // as c0 and checks that the result equals 'expected'.
  void testTransform(
      const std::string& expression,
      const VectorPtr& input,
      const VectorPtr& expected) {
    assertEqualVectors(
        expected, evaluate(expression, makeRowVector({input})));
  }
};

// Test transform with element-only lambda: transform(array, x -> x * 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

End with .

TEST_F(TransformTest, elementOnly) {
auto input = makeArrayVectorFromJson<int64_t>({
"[1, 2, 3]",
"[4, 5]",
"[6]",
"[]",
"null",
});

auto expected = makeArrayVectorFromJson<int64_t>({
"[2, 4, 6]",
"[8, 10]",
"[12]",
"[]",
"null",
});

testTransform("transform(c0, x -> x * 2)", input, expected);
}

// Test transform with an element-only lambda over arrays with null elements.
TEST_F(TransformTest, elementOnlyWithNulls) {
  const auto data = makeArrayVectorFromJson<int64_t>(
      {"[1, null, 3]", "[null, null]", "[4, 5, null]"});
  const auto doubled = makeArrayVectorFromJson<int64_t>(
      {"[2, null, 6]", "[null, null]", "[8, 10, null]"});
  testTransform("transform(c0, x -> x * 2)", data, doubled);
}

// Test transform with an (element, index) lambda:
// transform(array, (x, i) -> x + i).
TEST_F(TransformTest, elementWithIndex) {
  const auto data = makeArrayVectorFromJson<int64_t>(
      {"[10, 20, 30]", "[100, 200]", "[1000]", "[]", "null"});
  // Each element gains its zero-based position:
  //   [10+0, 20+1, 30+2], [100+0, 200+1], [1000+0], [], null.
  const auto summed = makeArrayVectorFromJson<int64_t>(
      {"[10, 21, 32]", "[100, 201]", "[1000]", "[]", "null"});
  testTransform("transform(c0, (x, i) -> add(x, i))", data, summed);
}

// Test transform that uses only the index: transform(array, (x, i) -> i).
TEST_F(TransformTest, indexOnly) {
  const auto data = makeArrayVectorFromJson<int64_t>(
      {"[100, 200, 300]", "[1, 2]", "[42]", "[]"});
  // The result holds the zero-based positions typed as INTEGER/int32,
  // matching Spark's IntegerType for the lambda's index parameter.
  const auto positions = makeArrayVectorFromJson<int32_t>(
      {"[0, 1, 2]", "[0, 1]", "[0]", "[]"});
  testTransform("transform(c0, (x, i) -> i)", data, positions);
}

// Test transform with an (element, index) lambda over null elements.
TEST_F(TransformTest, elementWithIndexAndNulls) {
  const auto data =
      makeArrayVectorFromJson<int64_t>({"[1, null, 3]", "[null, 5]"});
  // Null elements stay null; the rest gain their index:
  //   [1+0, null, 3+2] and [null, 5+1].
  const auto summed =
      makeArrayVectorFromJson<int64_t>({"[1, null, 5]", "[null, 6]"});
  testTransform("transform(c0, (x, i) -> add(x, i))", data, summed);
}

// Test transform combining element and index arithmetic.
TEST_F(TransformTest, indexMultiplication) {
  const auto data =
      makeArrayVectorFromJson<int64_t>({"[1, 2, 3, 4]", "[10, 20]"});
  // Each element is multiplied by (index + 1):
  //   [1*1, 2*2, 3*3, 4*4] and [10*1, 20*2].
  const auto scaled =
      makeArrayVectorFromJson<int64_t>({"[1, 4, 9, 16]", "[10, 40]"});
  testTransform(
      "transform(c0, (x, i) -> multiply(x, add(i, 1)))", data, scaled);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test
Loading