diff --git a/velox/functions/lib/TransformFunctionBase.h b/velox/functions/lib/TransformFunctionBase.h new file mode 100644 index 000000000000..fd74c577195a --- /dev/null +++ b/velox/functions/lib/TransformFunctionBase.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/vector/FunctionVector.h" + +namespace facebook::velox::functions { + +/// Base class for array transform functions. +/// Subclasses can override addIndexVector() to provide additional lambda +/// arguments (e.g., element index for Spark's transform function). +class TransformFunctionBase : public exec::VectorFunction { + public: + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + VELOX_CHECK_EQ(args.size(), 2); + + // Flatten input array. + exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); + auto& decodedArray = *arrayDecoder.get(); + + auto flatArray = flattenArray(rows, args[0], decodedArray); + auto newNumElements = flatArray->elements()->size(); + + std::vector lambdaArgs = {flatArray->elements()}; + + // Allow subclasses to add additional lambda arguments (e.g., index vector). + addIndexVector(args, flatArray, newNumElements, context, lambdaArgs); + + SelectivityVector validRowsInReusedResult = + toElementRows(newNumElements, rows, flatArray.get()); + + // Transformed elements. + VectorPtr newElements; + + auto elementToTopLevelRows = getElementToTopLevelRows( + newNumElements, rows, flatArray.get(), context.pool()); + + // Loop over lambda functions and apply these to elements of the base array. + // In most cases there will be only one function and the loop will run once. + auto it = args[1]->asUnchecked()->iterator(&rows); + while (auto entry = it.next()) { + auto elementRows = toElementRows( + newNumElements, *entry.rows, flatArray.get()); + auto wrapCapture = toWrapCapture( + newNumElements, entry.callable, *entry.rows, flatArray); + + entry.callable->apply( + elementRows, + &validRowsInReusedResult, + wrapCapture, + &context, + lambdaArgs, + elementToTopLevelRows, + &newElements); + } + + // Set nulls for rows not present in 'rows'. + BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows); + + VectorPtr localResult = std::make_shared( + flatArray->pool(), + outputType, + std::move(newNulls), + rows.end(), + flatArray->offsets(), + flatArray->sizes(), + newElements); + context.moveOrCopyResult(localResult, rows, result); + } + + /// Returns the base signature: array(T), function(T, U) -> array(U). + /// Subclasses can call this and add additional signatures. + static std::vector> signatures() { + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .typeVariable("U") + .returnType("array(U)") + .argumentType("array(T)") + .argumentType("function(T, U)") + .build()}; + } + + protected: + /// Override this method to add additional arguments to the lambda. + /// Default implementation does nothing (element-only transform). + /// @param args The original function arguments. + /// @param flatArray The flattened input array. + /// @param numElements Total number of elements across all arrays. + /// @param context The evaluation context. + /// @param lambdaArgs Output vector to append additional arguments to. + virtual void addIndexVector( + const std::vector& /*args*/, + const ArrayVectorPtr& /*flatArray*/, + vector_size_t /*numElements*/, + exec::EvalCtx& /*context*/, + std::vector& /*lambdaArgs*/) const {} +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/Transform.cpp b/velox/functions/prestosql/Transform.cpp index 3028a2c256d1..123635c87e31 100644 --- a/velox/functions/prestosql/Transform.cpp +++ b/velox/functions/prestosql/Transform.cpp @@ -13,88 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/expression/Expr.h" -#include "velox/expression/VectorFunction.h" -#include "velox/functions/lib/LambdaFunctionUtil.h" -#include "velox/functions/lib/RowsTranslationUtil.h" -#include "velox/vector/FunctionVector.h" +#include "velox/functions/lib/TransformFunctionBase.h" namespace facebook::velox::functions { namespace { // See documentation at https://prestodb.io/docs/current/functions/array.html -class TransformFunction : public exec::VectorFunction { - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& outputType, - exec::EvalCtx& context, - VectorPtr& result) const override { - VELOX_CHECK_EQ(args.size(), 2); - - // Flatten input array. - exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); - auto& decodedArray = *arrayDecoder.get(); - - auto flatArray = flattenArray(rows, args[0], decodedArray); - - std::vector lambdaArgs = {flatArray->elements()}; - auto newNumElements = flatArray->elements()->size(); - - SelectivityVector validRowsInReusedResult = - toElementRows(newNumElements, rows, flatArray.get()); - - // transformed elements - VectorPtr newElements; - - auto elementToTopLevelRows = getElementToTopLevelRows( - newNumElements, rows, flatArray.get(), context.pool()); - - // loop over lambda functions and apply these to elements of the base array; - // in most cases there will be only one function and the loop will run once - auto it = args[1]->asUnchecked()->iterator(&rows); - while (auto entry = it.next()) { - auto elementRows = toElementRows( - newNumElements, *entry.rows, flatArray.get()); - auto wrapCapture = toWrapCapture( - newNumElements, entry.callable, *entry.rows, flatArray); - - entry.callable->apply( - elementRows, - &validRowsInReusedResult, - wrapCapture, - &context, - lambdaArgs, - elementToTopLevelRows, - &newElements); - } - - // Set nulls for rows not present in 'rows'. - BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows); - - VectorPtr localResult = std::make_shared( - flatArray->pool(), - outputType, - std::move(newNulls), - rows.end(), - flatArray->offsets(), - flatArray->sizes(), - newElements); - context.moveOrCopyResult(localResult, rows, result); - } - - static std::vector> signatures() { - // array(T), function(T, U) -> array(U) - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .typeVariable("U") - .returnType("array(U)") - .argumentType("array(T)") - .argumentType("function(T, U)") - .build()}; - } +class TransformFunction : public TransformFunctionBase { + // Inherits apply() and signatures() from base class. + // No additional lambda arguments needed for Presto. }; + } // namespace /// transform is null preserving for the array. But since an diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 3ae83a3dbce1..d2c519c0d901 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -17,6 +17,7 @@ velox_add_library( velox_functions_spark_impl ArrayGetFunction.cpp ArraySort.cpp + Transform.cpp CharVarcharUtils.cpp Comparisons.cpp ConcatWs.cpp diff --git a/velox/functions/sparksql/Transform.cpp b/velox/functions/sparksql/Transform.cpp new file mode 100644 index 000000000000..ede46e9be13b --- /dev/null +++ b/velox/functions/sparksql/Transform.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/BitUtil.h" +#include "velox/functions/lib/TransformFunctionBase.h" + +namespace facebook::velox::functions::sparksql { +namespace { + +/// Spark's transform function supports both signatures: +/// 1. transform(array, x -> expr) - element only +/// 2. transform(array, (x, i) -> expr) - element + index (Spark-specific) +/// +/// See Spark documentation: +/// https://spark.apache.org/docs/latest/api/sql/index.html#transform +class TransformFunction : public TransformFunctionBase { + public: + /// Returns both base signature and Spark-specific signature with index. + static std::vector> signatures() { + auto sigs = TransformFunctionBase::signatures(); + // Add Spark-specific signature: array(T), function(T, integer, U) -> + // array(U). Spark uses IntegerType (32-bit) for the index parameter. + sigs.push_back( + exec::FunctionSignatureBuilder() + .typeVariable("T") + .typeVariable("U") + .returnType("array(U)") + .argumentType("array(T)") + .argumentType("function(T, integer, U)") + .build()); + return sigs; + } + + protected: + /// Adds index vector to lambda arguments if the lambda expects it. + void addIndexVector( + const std::vector& args, + const ArrayVectorPtr& flatArray, + vector_size_t numElements, + exec::EvalCtx& context, + std::vector& lambdaArgs) const override { + // Check the lambda function type to see if it expects 2 input arguments. + // function(T, U) has 2 children (input T, output U) -> 1 input arg. + // function(T, integer, U) has 3 children (input T, index integer, output U) + // -> 2 input args. + auto functionType = args[1]->type(); + bool withIndex = functionType->size() == 3; + + if (withIndex) { + lambdaArgs.push_back(createIndexVector(flatArray, numElements, context)); + } + } + + private: + /// Creates an index vector where each element contains its position within + /// its respective array. For example, if we have arrays [a, b] and [c, d, e], + /// the index vector will be [0, 1, 0, 1, 2]. + /// Spark uses IntegerType (32-bit) for the index. + static VectorPtr createIndexVector( + const ArrayVectorPtr& flatArray, + vector_size_t numElements, + exec::EvalCtx& context) { + auto* pool = context.pool(); + auto indexVector = + BaseVector::create>(INTEGER(), numElements, pool); + + auto* rawOffsets = flatArray->rawOffsets(); + auto* rawSizes = flatArray->rawSizes(); + auto* rawNulls = flatArray->rawNulls(); + auto* rawIndices = indexVector->mutableRawValues(); + + for (vector_size_t row = 0; row < flatArray->size(); ++row) { + // Skip null arrays. + if (rawNulls && bits::isBitNull(rawNulls, row)) { + continue; + } + auto offset = rawOffsets[row]; + auto size = rawSizes[row]; + for (vector_size_t i = 0; i < size; ++i) { + rawIndices[offset + i] = i; + } + } + + return indexVector; + } +}; + +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA( + udf_spark_transform, + TransformFunction::signatures(), + exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), + std::make_unique()); + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterArray.cpp b/velox/functions/sparksql/registration/RegisterArray.cpp index 9706ffc39f85..3c8d89497447 100644 --- a/velox/functions/sparksql/registration/RegisterArray.cpp +++ b/velox/functions/sparksql/registration/RegisterArray.cpp @@ -35,7 +35,6 @@ namespace facebook::velox::functions { // vector function definition. // Higher order functions. void registerSparkArrayFunctions(const std::string& prefix) { - VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform"); VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "aggregate"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, prefix + "array"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, prefix + "array_contains"); @@ -52,6 +51,12 @@ void registerSparkArrayFunctions(const std::string& prefix) { } namespace sparksql { + +// Register Spark-specific transform with index support. +void registerSparkTransform(const std::string& prefix) { + VELOX_REGISTER_VECTOR_FUNCTION(udf_spark_transform, prefix + "transform"); +} + template inline void registerArrayConcatFunction(const std::string& prefix) { registerFunction< @@ -216,6 +221,8 @@ void registerArrayFunctions(const std::string& prefix) { registerArrayRemoveFunctions(prefix); registerArrayPrependFunctions(prefix); registerSparkArrayFunctions(prefix); + // Register Spark-specific transform with index support. + registerSparkTransform(prefix); // Register array sort functions. exec::registerStatefulVectorFunction( prefix + "array_sort", arraySortSignatures(true), makeArraySortAsc); diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index ad1b5833d4e7..6a4173a7f0af 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -72,6 +72,7 @@ add_executable( StringTest.cpp StringToMapTest.cpp ToJsonTest.cpp + TransformTest.cpp UnBase64Test.cpp UnscaledValueFunctionTest.cpp UpperLowerTest.cpp diff --git a/velox/functions/sparksql/tests/TransformTest.cpp b/velox/functions/sparksql/tests/TransformTest.cpp new file mode 100644 index 000000000000..b623f17f6f2e --- /dev/null +++ b/velox/functions/sparksql/tests/TransformTest.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +using namespace facebook::velox::test; + +class TransformTest : public SparkFunctionBaseTest { + protected: + void testTransform( + const std::string& expression, + const VectorPtr& input, + const VectorPtr& expected) { + auto result = evaluate(expression, makeRowVector({input})); + assertEqualVectors(expected, result); + } +}; + +// Test transform with element-only lambda: transform(array, x -> x * 2). +TEST_F(TransformTest, elementOnly) { + auto input = makeArrayVectorFromJson({ + "[1, 2, 3]", + "[4, 5]", + "[6]", + "[]", + "null", + }); + + auto expected = makeArrayVectorFromJson({ + "[2, 4, 6]", + "[8, 10]", + "[12]", + "[]", + "null", + }); + + testTransform("transform(c0, x -> x * 2)", input, expected); +} + +// Test transform with element-only lambda containing null elements. +TEST_F(TransformTest, elementOnlyWithNulls) { + auto input = makeArrayVectorFromJson({ + "[1, null, 3]", + "[null, null]", + "[4, 5, null]", + }); + + auto expected = makeArrayVectorFromJson({ + "[2, null, 6]", + "[null, null]", + "[8, 10, null]", + }); + + testTransform("transform(c0, x -> x * 2)", input, expected); +} + +// Test transform with element + index lambda: transform(array, (x, i) -> x + i) +TEST_F(TransformTest, elementWithIndex) { + auto input = makeArrayVectorFromJson({ + "[10, 20, 30]", + "[100, 200]", + "[1000]", + "[]", + "null", + }); + + // Expected: element + index + // [10+0, 20+1, 30+2] = [10, 21, 32] + // [100+0, 200+1] = [100, 201] + // [1000+0] = [1000] + // [] = [] + // null = null + auto expected = makeArrayVectorFromJson({ + "[10, 21, 32]", + "[100, 201]", + "[1000]", + "[]", + "null", + }); + + testTransform("transform(c0, (x, i) -> add(x, i))", input, expected); +} + +// Test transform with index-only usage: transform(array, (x, i) -> i) +TEST_F(TransformTest, indexOnly) { + auto input = makeArrayVectorFromJson({ + "[100, 200, 300]", + "[1, 2]", + "[42]", + "[]", + }); + + // Expected: just the indices (as INTEGER/int32, matching Spark's IntegerType) + auto expected = makeArrayVectorFromJson({ + "[0, 1, 2]", + "[0, 1]", + "[0]", + "[]", + }); + + testTransform("transform(c0, (x, i) -> i)", input, expected); +} + +// Test transform with index and null elements. +TEST_F(TransformTest, elementWithIndexAndNulls) { + auto input = makeArrayVectorFromJson({ + "[1, null, 3]", + "[null, 5]", + }); + + // Expected: element + index, null elements stay null + // [1+0, null, 3+2] = [1, null, 5] + // [null, 5+1] = [null, 6] + auto expected = makeArrayVectorFromJson({ + "[1, null, 5]", + "[null, 6]", + }); + + testTransform("transform(c0, (x, i) -> add(x, i))", input, expected); +} + +// Test transform with index multiplication. +TEST_F(TransformTest, indexMultiplication) { + auto input = makeArrayVectorFromJson({ + "[1, 2, 3, 4]", + "[10, 20]", + }); + + // Expected: element * (index + 1) + // [1*1, 2*2, 3*3, 4*4] = [1, 4, 9, 16] + // [10*1, 20*2] = [10, 40] + auto expected = makeArrayVectorFromJson({ + "[1, 4, 9, 16]", + "[10, 40]", + }); + + testTransform( + "transform(c0, (x, i) -> multiply(x, add(i, 1)))", input, expected); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test