-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat(sparksql): Add transform with index parameter support #16211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| /* | ||
| * Copyright (c) Facebook, Inc. and its affiliates. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include "velox/common/base/BitUtil.h" | ||
| #include "velox/expression/Expr.h" | ||
| #include "velox/expression/VectorFunction.h" | ||
| #include "velox/functions/lib/LambdaFunctionUtil.h" | ||
| #include "velox/functions/lib/RowsTranslationUtil.h" | ||
| #include "velox/vector/FunctionVector.h" | ||
|
|
||
| namespace facebook::velox::functions::sparksql { | ||
| namespace { | ||
|
|
||
| /// Spark's transform function supporting both signatures: | ||
| /// 1. transform(array, x -> expr) - element only | ||
| /// 2. transform(array, (x, i) -> expr) - element + index (Spark-specific) | ||
| /// | ||
| /// See Spark documentation: | ||
| /// https://spark.apache.org/docs/latest/api/sql/index.html#transform | ||
| class SparkTransformFunction : public exec::VectorFunction { | ||
|
||
| public: | ||
| void apply( | ||
|
||
| const SelectivityVector& rows, | ||
| std::vector<VectorPtr>& args, | ||
|
Check warning on line 37 in velox/functions/sparksql/Transform.cpp
|
||
| const TypePtr& outputType, | ||
| exec::EvalCtx& context, | ||
| VectorPtr& result) const override { | ||
| VELOX_CHECK_EQ(args.size(), 2); | ||
|
|
||
| // Flatten input array. | ||
| exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); | ||
| auto& decodedArray = *arrayDecoder.get(); | ||
|
|
||
| auto flatArray = flattenArray(rows, args[0], decodedArray); | ||
| auto newNumElements = flatArray->elements()->size(); | ||
|
|
||
| // Determine if we need to pass index to the lambda. | ||
| // Check the lambda function type to see if it expects 2 input arguments. | ||
| // function(T, U) has 2 children (input T, output U) -> 1 input arg. | ||
| // function(T, integer, U) has 3 children (input T, index integer, output U) | ||
| // -> 2 input args. | ||
| auto functionType = args[1]->type(); | ||
| bool withIndex = functionType->size() == 3; | ||
|
|
||
| std::vector<VectorPtr> lambdaArgs = {flatArray->elements()}; | ||
|
|
||
| // If lambda expects index, create index vector. | ||
| if (withIndex) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the only difference from Presto's TransformFunction is here — is that right? If so, I would prefer a shared base TransformFunction with a virtual addIndexVectors(lambdaArgs) hook that has an empty default implementation. |
||
| auto indexVector = createIndexVector(flatArray, newNumElements, context); | ||
| lambdaArgs.push_back(indexVector); | ||
| } | ||
|
|
||
| SelectivityVector validRowsInReusedResult = | ||
| toElementRows<ArrayVector>(newNumElements, rows, flatArray.get()); | ||
|
|
||
| // Transformed elements. | ||
| VectorPtr newElements; | ||
|
|
||
| auto elementToTopLevelRows = getElementToTopLevelRows( | ||
| newNumElements, rows, flatArray.get(), context.pool()); | ||
|
|
||
| // Loop over lambda functions and apply these to elements of the base array. | ||
| // In most cases there will be only one function and the loop will run once. | ||
| auto lambdaIt = args[1]->asUnchecked<FunctionVector>()->iterator(&rows); | ||
| while (auto entry = lambdaIt.next()) { | ||
| auto elementRows = toElementRows<ArrayVector>( | ||
| newNumElements, *entry.rows, flatArray.get()); | ||
| auto wrapCapture = toWrapCapture<ArrayVector>( | ||
| newNumElements, entry.callable, *entry.rows, flatArray); | ||
|
|
||
| entry.callable->apply( | ||
| elementRows, | ||
| &validRowsInReusedResult, | ||
| wrapCapture, | ||
| &context, | ||
| lambdaArgs, | ||
| elementToTopLevelRows, | ||
| &newElements); | ||
| } | ||
|
|
||
| // Set nulls for rows not present in 'rows'. | ||
| BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows); | ||
|
|
||
| VectorPtr localResult = std::make_shared<ArrayVector>( | ||
| flatArray->pool(), | ||
| outputType, | ||
| std::move(newNulls), | ||
| rows.end(), | ||
| flatArray->offsets(), | ||
| flatArray->sizes(), | ||
| newElements); | ||
| context.moveOrCopyResult(localResult, rows, result); | ||
| } | ||
|
|
||
| static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() { | ||
|
||
| return { | ||
| // Signature 1: array(T), function(T, U) -> array(U) (element only) | ||
| exec::FunctionSignatureBuilder() | ||
| .typeVariable("T") | ||
| .typeVariable("U") | ||
| .returnType("array(U)") | ||
| .argumentType("array(T)") | ||
| .argumentType("function(T, U)") | ||
| .build(), | ||
| // Signature 2: array(T), function(T, integer, U) -> array(U) (element + | ||
| // index). Spark uses IntegerType (32-bit) for the index parameter. | ||
| exec::FunctionSignatureBuilder() | ||
| .typeVariable("T") | ||
| .typeVariable("U") | ||
| .returnType("array(U)") | ||
| .argumentType("array(T)") | ||
| .argumentType("function(T, integer, U)") | ||
| .build()}; | ||
| } | ||
|
|
||
| private: | ||
| /// Creates an index vector where each element contains its position within | ||
| /// its respective array. For example, if we have arrays [a, b] and [c, d, e], | ||
| /// the index vector will be [0, 1, 0, 1, 2]. | ||
| /// Spark uses IntegerType (32-bit) for the index. | ||
| static VectorPtr createIndexVector( | ||
| const std::shared_ptr<ArrayVector>& flatArray, | ||
| vector_size_t numElements, | ||
| exec::EvalCtx& context) { | ||
| auto* pool = context.pool(); | ||
| auto indexVector = | ||
| BaseVector::create<FlatVector<int32_t>>(INTEGER(), numElements, pool); | ||
|
|
||
| auto* rawOffsets = flatArray->rawOffsets(); | ||
| auto* rawSizes = flatArray->rawSizes(); | ||
| auto* rawNulls = flatArray->rawNulls(); | ||
| auto* rawIndices = indexVector->mutableRawValues(); | ||
|
|
||
| for (vector_size_t row = 0; row < flatArray->size(); ++row) { | ||
| // Skip null arrays. | ||
| if (rawNulls && bits::isBitNull(rawNulls, row)) { | ||
| continue; | ||
| } | ||
| auto offset = rawOffsets[row]; | ||
| auto size = rawSizes[row]; | ||
| for (vector_size_t i = 0; i < size; ++i) { | ||
| rawIndices[offset + i] = i; | ||
| } | ||
| } | ||
|
|
||
| return indexVector; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA( | ||
| udf_spark_transform, | ||
| SparkTransformFunction::signatures(), | ||
| exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), | ||
| std::make_unique<SparkTransformFunction>()); | ||
|
|
||
| } // namespace facebook::velox::functions::sparksql | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| /* | ||
| * Copyright (c) Facebook, Inc. and its affiliates. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include "velox/common/base/tests/GTestUtils.h" | ||
| #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" | ||
|
|
||
| namespace facebook::velox::functions::sparksql::test { | ||
| namespace { | ||
|
|
||
| using namespace facebook::velox::test; | ||
|
|
||
| class TransformTest : public SparkFunctionBaseTest { | ||
| protected: | ||
| void testTransform( | ||
| const std::string& expression, | ||
| const VectorPtr& input, | ||
| const VectorPtr& expected) { | ||
| auto result = evaluate(expression, makeRowVector({input})); | ||
| assertEqualVectors(expected, result); | ||
| } | ||
| }; | ||
|
|
||
| // Test transform with element-only lambda: transform(array, x -> x * 2) | ||
|
||
| TEST_F(TransformTest, elementOnly) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, 2, 3]", | ||
| "[4, 5]", | ||
| "[6]", | ||
| "[]", | ||
| "null", | ||
| }); | ||
|
|
||
| auto expected = makeArrayVectorFromJson<int64_t>({ | ||
| "[2, 4, 6]", | ||
| "[8, 10]", | ||
| "[12]", | ||
| "[]", | ||
| "null", | ||
| }); | ||
|
|
||
| testTransform("transform(c0, x -> x * 2)", input, expected); | ||
| } | ||
|
|
||
| // Test transform with element-only lambda containing null elements. | ||
| TEST_F(TransformTest, elementOnlyWithNulls) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, null, 3]", | ||
| "[null, null]", | ||
| "[4, 5, null]", | ||
| }); | ||
|
|
||
| auto expected = makeArrayVectorFromJson<int64_t>({ | ||
| "[2, null, 6]", | ||
| "[null, null]", | ||
| "[8, 10, null]", | ||
| }); | ||
|
|
||
| testTransform("transform(c0, x -> x * 2)", input, expected); | ||
| } | ||
|
|
||
| // Test transform with element + index lambda: transform(array, (x, i) -> x + i) | ||
| TEST_F(TransformTest, elementWithIndex) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[10, 20, 30]", | ||
| "[100, 200]", | ||
| "[1000]", | ||
| "[]", | ||
| "null", | ||
| }); | ||
|
|
||
| // Expected: element + index | ||
| // [10+0, 20+1, 30+2] = [10, 21, 32] | ||
| // [100+0, 200+1] = [100, 201] | ||
| // [1000+0] = [1000] | ||
| // [] = [] | ||
| // null = null | ||
| auto expected = makeArrayVectorFromJson<int64_t>({ | ||
| "[10, 21, 32]", | ||
| "[100, 201]", | ||
| "[1000]", | ||
| "[]", | ||
| "null", | ||
| }); | ||
|
|
||
| testTransform("transform(c0, (x, i) -> add(x, i))", input, expected); | ||
| } | ||
|
|
||
| // Test transform with index-only usage: transform(array, (x, i) -> i) | ||
| TEST_F(TransformTest, indexOnly) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[100, 200, 300]", | ||
| "[1, 2]", | ||
| "[42]", | ||
| "[]", | ||
| }); | ||
|
|
||
| // Expected: just the indices (as INTEGER/int32, matching Spark's IntegerType) | ||
| auto expected = makeArrayVectorFromJson<int32_t>({ | ||
| "[0, 1, 2]", | ||
| "[0, 1]", | ||
| "[0]", | ||
| "[]", | ||
| }); | ||
|
|
||
| testTransform("transform(c0, (x, i) -> i)", input, expected); | ||
| } | ||
|
|
||
| // Test transform with index and null elements. | ||
| TEST_F(TransformTest, elementWithIndexAndNulls) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, null, 3]", | ||
| "[null, 5]", | ||
| }); | ||
|
|
||
| // Expected: element + index, null elements stay null | ||
| // [1+0, null, 3+2] = [1, null, 5] | ||
| // [null, 5+1] = [null, 6] | ||
| auto expected = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, null, 5]", | ||
| "[null, 6]", | ||
| }); | ||
|
|
||
| testTransform("transform(c0, (x, i) -> add(x, i))", input, expected); | ||
| } | ||
|
|
||
| // Test transform with index multiplication. | ||
| TEST_F(TransformTest, indexMultiplication) { | ||
| auto input = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, 2, 3, 4]", | ||
| "[10, 20]", | ||
| }); | ||
|
|
||
| // Expected: element * (index + 1) | ||
| // [1*1, 2*2, 3*3, 4*4] = [1, 4, 9, 16] | ||
| // [10*1, 20*2] = [10, 40] | ||
| auto expected = makeArrayVectorFromJson<int64_t>({ | ||
| "[1, 4, 9, 16]", | ||
| "[10, 40]", | ||
| }); | ||
|
|
||
| testTransform( | ||
| "transform(c0, (x, i) -> multiply(x, add(i, 1)))", input, expected); | ||
| } | ||
|
|
||
| } // namespace | ||
| } // namespace facebook::velox::functions::sparksql::test | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
supporting -> supports