Skip to content

Commit c28d2c1

Browse files
committed
Refactor: Extract base TransformFunctionBase class
Address reviewer feedback to share code between Presto and Spark transform: - Create TransformFunctionBase in velox/functions/lib/ with virtual addIndexVector() - Presto TransformFunction inherits base directly (no override needed) - Spark TransformFunction overrides addIndexVector() to add index parameter - Spark signatures() calls base and adds the (T, integer, U) signature All 13 transform tests pass (6 Spark + 7 Presto).
1 parent 7656e89 commit c28d2c1

File tree

3 files changed

+153
-165
lines changed

3 files changed

+153
-165
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/expression/Expr.h"
19+
#include "velox/expression/VectorFunction.h"
20+
#include "velox/functions/lib/LambdaFunctionUtil.h"
21+
#include "velox/functions/lib/RowsTranslationUtil.h"
22+
#include "velox/vector/FunctionVector.h"
23+
24+
namespace facebook::velox::functions {
25+
26+
/// Base class for array transform functions.
27+
/// Subclasses can override addIndexVector() to provide additional lambda
28+
/// arguments (e.g., element index for Spark's transform function).
29+
class TransformFunctionBase : public exec::VectorFunction {
30+
public:
31+
void apply(
32+
const SelectivityVector& rows,
33+
std::vector<VectorPtr>& args,
34+
const TypePtr& outputType,
35+
exec::EvalCtx& context,
36+
VectorPtr& result) const override {
37+
VELOX_CHECK_EQ(args.size(), 2);
38+
39+
// Flatten input array.
40+
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
41+
auto& decodedArray = *arrayDecoder.get();
42+
43+
auto flatArray = flattenArray(rows, args[0], decodedArray);
44+
auto newNumElements = flatArray->elements()->size();
45+
46+
std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};
47+
48+
// Allow subclasses to add additional lambda arguments (e.g., index vector).
49+
addIndexVector(args, flatArray, newNumElements, context, lambdaArgs);
50+
51+
SelectivityVector validRowsInReusedResult =
52+
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());
53+
54+
// Transformed elements.
55+
VectorPtr newElements;
56+
57+
auto elementToTopLevelRows = getElementToTopLevelRows(
58+
newNumElements, rows, flatArray.get(), context.pool());
59+
60+
// Loop over lambda functions and apply these to elements of the base array.
61+
// In most cases there will be only one function and the loop will run once.
62+
auto it = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
63+
while (auto entry = it.next()) {
64+
auto elementRows = toElementRows<ArrayVector>(
65+
newNumElements, *entry.rows, flatArray.get());
66+
auto wrapCapture = toWrapCapture<ArrayVector>(
67+
newNumElements, entry.callable, *entry.rows, flatArray);
68+
69+
entry.callable->apply(
70+
elementRows,
71+
&validRowsInReusedResult,
72+
wrapCapture,
73+
&context,
74+
lambdaArgs,
75+
elementToTopLevelRows,
76+
&newElements);
77+
}
78+
79+
// Set nulls for rows not present in 'rows'.
80+
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);
81+
82+
VectorPtr localResult = std::make_shared<ArrayVector>(
83+
flatArray->pool(),
84+
outputType,
85+
std::move(newNulls),
86+
rows.end(),
87+
flatArray->offsets(),
88+
flatArray->sizes(),
89+
newElements);
90+
context.moveOrCopyResult(localResult, rows, result);
91+
}
92+
93+
/// Returns the base signature: array(T), function(T, U) -> array(U).
94+
/// Subclasses can call this and add additional signatures.
95+
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
96+
return {exec::FunctionSignatureBuilder()
97+
.typeVariable("T")
98+
.typeVariable("U")
99+
.returnType("array(U)")
100+
.argumentType("array(T)")
101+
.argumentType("function(T, U)")
102+
.build()};
103+
}
104+
105+
protected:
106+
/// Override this method to add additional arguments to the lambda.
107+
/// Default implementation does nothing (element-only transform).
108+
/// @param args The original function arguments.
109+
/// @param flatArray The flattened input array.
110+
/// @param numElements Total number of elements across all arrays.
111+
/// @param context The evaluation context.
112+
/// @param lambdaArgs Output vector to append additional arguments to.
113+
virtual void addIndexVector(
114+
const std::vector<VectorPtr>& /*args*/,
115+
const ArrayVectorPtr& /*flatArray*/,
116+
vector_size_t /*numElements*/,
117+
exec::EvalCtx& /*context*/,
118+
std::vector<VectorPtr>& /*lambdaArgs*/) const {}
119+
};
120+
121+
} // namespace facebook::velox::functions

velox/functions/prestosql/Transform.cpp

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,88 +13,17 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#include "velox/expression/Expr.h"
17-
#include "velox/expression/VectorFunction.h"
18-
#include "velox/functions/lib/LambdaFunctionUtil.h"
19-
#include "velox/functions/lib/RowsTranslationUtil.h"
20-
#include "velox/vector/FunctionVector.h"
16+
#include "velox/functions/lib/TransformFunctionBase.h"
2117

2218
namespace facebook::velox::functions {
2319
namespace {
2420

2521
// See documentation at https://prestodb.io/docs/current/functions/array.html
26-
class TransformFunction : public exec::VectorFunction {
27-
public:
28-
void apply(
29-
const SelectivityVector& rows,
30-
std::vector<VectorPtr>& args,
31-
const TypePtr& outputType,
32-
exec::EvalCtx& context,
33-
VectorPtr& result) const override {
34-
VELOX_CHECK_EQ(args.size(), 2);
35-
36-
// Flatten input array.
37-
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
38-
auto& decodedArray = *arrayDecoder.get();
39-
40-
auto flatArray = flattenArray(rows, args[0], decodedArray);
41-
42-
std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};
43-
auto newNumElements = flatArray->elements()->size();
44-
45-
SelectivityVector validRowsInReusedResult =
46-
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());
47-
48-
// transformed elements
49-
VectorPtr newElements;
50-
51-
auto elementToTopLevelRows = getElementToTopLevelRows(
52-
newNumElements, rows, flatArray.get(), context.pool());
53-
54-
// loop over lambda functions and apply these to elements of the base array;
55-
// in most cases there will be only one function and the loop will run once
56-
auto it = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
57-
while (auto entry = it.next()) {
58-
auto elementRows = toElementRows<ArrayVector>(
59-
newNumElements, *entry.rows, flatArray.get());
60-
auto wrapCapture = toWrapCapture<ArrayVector>(
61-
newNumElements, entry.callable, *entry.rows, flatArray);
62-
63-
entry.callable->apply(
64-
elementRows,
65-
&validRowsInReusedResult,
66-
wrapCapture,
67-
&context,
68-
lambdaArgs,
69-
elementToTopLevelRows,
70-
&newElements);
71-
}
72-
73-
// Set nulls for rows not present in 'rows'.
74-
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);
75-
76-
VectorPtr localResult = std::make_shared<ArrayVector>(
77-
flatArray->pool(),
78-
outputType,
79-
std::move(newNulls),
80-
rows.end(),
81-
flatArray->offsets(),
82-
flatArray->sizes(),
83-
newElements);
84-
context.moveOrCopyResult(localResult, rows, result);
85-
}
86-
87-
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
88-
// array(T), function(T, U) -> array(U)
89-
return {exec::FunctionSignatureBuilder()
90-
.typeVariable("T")
91-
.typeVariable("U")
92-
.returnType("array(U)")
93-
.argumentType("array(T)")
94-
.argumentType("function(T, U)")
95-
.build()};
96-
}
22+
class TransformFunction : public TransformFunctionBase {
23+
// Inherits apply() and signatures() from base class.
24+
// No additional lambda arguments needed for Presto.
9725
};
26+
9827
} // namespace
9928

10029
/// transform is null preserving for the array. But since an

velox/functions/sparksql/Transform.cpp

Lines changed: 27 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
*/
1616

1717
#include "velox/common/base/BitUtil.h"
18-
#include "velox/expression/Expr.h"
19-
#include "velox/expression/VectorFunction.h"
20-
#include "velox/functions/lib/LambdaFunctionUtil.h"
21-
#include "velox/functions/lib/RowsTranslationUtil.h"
22-
#include "velox/vector/FunctionVector.h"
18+
#include "velox/functions/lib/TransformFunctionBase.h"
2319

2420
namespace facebook::velox::functions::sparksql {
2521
namespace {
@@ -30,100 +26,42 @@ namespace {
3026
///
3127
/// See Spark documentation:
3228
/// https://spark.apache.org/docs/latest/api/sql/index.html#transform
33-
class TransformFunction : public exec::VectorFunction {
29+
class TransformFunction : public TransformFunctionBase {
3430
public:
35-
void apply(
36-
const SelectivityVector& rows,
37-
std::vector<VectorPtr>& args,
38-
const TypePtr& outputType,
39-
exec::EvalCtx& context,
40-
VectorPtr& result) const override {
41-
VELOX_CHECK_EQ(args.size(), 2);
42-
43-
// Flatten input array.
44-
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
45-
auto& decodedArray = *arrayDecoder.get();
46-
47-
auto flatArray = flattenArray(rows, args[0], decodedArray);
48-
auto newNumElements = flatArray->elements()->size();
31+
/// Returns both base signature and Spark-specific signature with index.
32+
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
33+
auto sigs = TransformFunctionBase::signatures();
34+
// Add Spark-specific signature: array(T), function(T, integer, U) ->
35+
// array(U). Spark uses IntegerType (32-bit) for the index parameter.
36+
sigs.push_back(
37+
exec::FunctionSignatureBuilder()
38+
.typeVariable("T")
39+
.typeVariable("U")
40+
.returnType("array(U)")
41+
.argumentType("array(T)")
42+
.argumentType("function(T, integer, U)")
43+
.build());
44+
return sigs;
45+
}
4946

50-
// Determine if we need to pass index to the lambda.
47+
protected:
48+
/// Adds index vector to lambda arguments if the lambda expects it.
49+
void addIndexVector(
50+
const std::vector<VectorPtr>& args,
51+
const ArrayVectorPtr& flatArray,
52+
vector_size_t numElements,
53+
exec::EvalCtx& context,
54+
std::vector<VectorPtr>& lambdaArgs) const override {
5155
// Check the lambda function type to see if it expects 2 input arguments.
5256
// function(T, U) has 2 children (input T, output U) -> 1 input arg.
5357
// function(T, integer, U) has 3 children (input T, index integer, output U)
5458
// -> 2 input args.
5559
auto functionType = args[1]->type();
5660
bool withIndex = functionType->size() == 3;
5761

58-
std::vector<VectorPtr> lambdaArgs = {flatArray->elements()};
59-
60-
// If lambda expects index, create index vector.
6162
if (withIndex) {
62-
auto indexVector = createIndexVector(flatArray, newNumElements, context);
63-
lambdaArgs.push_back(indexVector);
64-
}
65-
66-
SelectivityVector validRowsInReusedResult =
67-
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());
68-
69-
// Transformed elements.
70-
VectorPtr newElements;
71-
72-
auto elementToTopLevelRows = getElementToTopLevelRows(
73-
newNumElements, rows, flatArray.get(), context.pool());
74-
75-
// Loop over lambda functions and apply these to elements of the base array.
76-
// In most cases there will be only one function and the loop will run once.
77-
auto lambdaIt = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
78-
while (auto entry = lambdaIt.next()) {
79-
auto elementRows = toElementRows<ArrayVector>(
80-
newNumElements, *entry.rows, flatArray.get());
81-
auto wrapCapture = toWrapCapture<ArrayVector>(
82-
newNumElements, entry.callable, *entry.rows, flatArray);
83-
84-
entry.callable->apply(
85-
elementRows,
86-
&validRowsInReusedResult,
87-
wrapCapture,
88-
&context,
89-
lambdaArgs,
90-
elementToTopLevelRows,
91-
&newElements);
63+
lambdaArgs.push_back(createIndexVector(flatArray, numElements, context));
9264
}
93-
94-
// Set nulls for rows not present in 'rows'.
95-
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);
96-
97-
VectorPtr localResult = std::make_shared<ArrayVector>(
98-
flatArray->pool(),
99-
outputType,
100-
std::move(newNulls),
101-
rows.end(),
102-
flatArray->offsets(),
103-
flatArray->sizes(),
104-
newElements);
105-
context.moveOrCopyResult(localResult, rows, result);
106-
}
107-
108-
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
109-
return {
110-
// Signature 1: array(T), function(T, U) -> array(U) (element only)
111-
exec::FunctionSignatureBuilder()
112-
.typeVariable("T")
113-
.typeVariable("U")
114-
.returnType("array(U)")
115-
.argumentType("array(T)")
116-
.argumentType("function(T, U)")
117-
.build(),
118-
// Signature 2: array(T), function(T, integer, U) -> array(U) (element +
119-
// index). Spark uses IntegerType (32-bit) for the index parameter.
120-
exec::FunctionSignatureBuilder()
121-
.typeVariable("T")
122-
.typeVariable("U")
123-
.returnType("array(U)")
124-
.argumentType("array(T)")
125-
.argumentType("function(T, integer, U)")
126-
.build()};
12765
}
12866

12967
private:
@@ -132,7 +70,7 @@ class TransformFunction : public exec::VectorFunction {
13270
/// the index vector will be [0, 1, 0, 1, 2].
13371
/// Spark uses IntegerType (32-bit) for the index.
13472
static VectorPtr createIndexVector(
135-
const std::shared_ptr<ArrayVector>& flatArray,
73+
const ArrayVectorPtr& flatArray,
13674
vector_size_t numElements,
13775
exec::EvalCtx& context) {
13876
auto* pool = context.pool();

0 commit comments

Comments
 (0)