Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class RemoteThriftFunction : public RemoteVectorFunction {
const RemoteThriftVectorFunctionMetadata& metadata)
: RemoteVectorFunction(functionName, inputArgs, metadata),
location_(metadata.location),
thriftClient_(getThriftClient(location_, &eventBase_)) {}
client_(createClient(metadata)) {}

std::unique_ptr<remote::RemoteFunctionResponse> invokeRemoteFunction(
const remote::RemoteFunctionRequest& request) const override {
auto remoteResponse = std::make_unique<remote::RemoteFunctionResponse>();
thriftClient_->sync_invokeFunction(*remoteResponse, request);
client_->invokeFunction(*remoteResponse, request);
return remoteResponse;
}

Expand All @@ -47,10 +47,18 @@ class RemoteThriftFunction : public RemoteVectorFunction {
}

private:
std::unique_ptr<IRemoteFunctionClient> createClient(
const RemoteThriftVectorFunctionMetadata& metadata) {
if (metadata.clientFactory) {
return metadata.clientFactory(metadata.location, &eventBase_);
}
return getDefaultRemoteFunctionClient(metadata.location, &eventBase_);
}

folly::SocketAddress location_;
folly::EventBase eventBase_;

std::unique_ptr<RemoteFunctionClient> thriftClient_;
std::unique_ptr<IRemoteFunctionClient> client_;
};

std::shared_ptr<exec::VectorFunction> createRemoteFunction(
Expand Down
6 changes: 6 additions & 0 deletions velox/functions/remote/client/Remote.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <folly/SocketAddress.h>
#include "velox/functions/remote/client/RemoteVectorFunction.h"
#include "velox/functions/remote/client/ThriftClient.h"

namespace facebook::velox::functions {

Expand All @@ -27,6 +28,11 @@ struct RemoteThriftVectorFunctionMetadata
/// Note that this can hold a network location (ip/port pair) or a unix domain
/// socket path (see SocketAddress::makeFromPath()).
folly::SocketAddress location;

/// Optional factory for creating remote function clients. If not set, the
/// default thrift client factory is used. This enables dependency injection
/// for testing with mock clients.
RemoteFunctionClientFactory clientFactory;
};

/// Registers a new remote function. It will use the meatadata defined in
Expand Down
43 changes: 43 additions & 0 deletions velox/functions/remote/client/ThriftClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,51 @@ namespace facebook::velox::functions {
using RemoteFunctionClient =
apache::thrift::Client<remote::RemoteFunctionService>;

/// Abstract interface for the remote function client, enabling dependency
/// injection and mocking in tests.
class IRemoteFunctionClient {
public:
virtual ~IRemoteFunctionClient() = default;

/// Invokes the remote function synchronously.
virtual void invokeFunction(
remote::RemoteFunctionResponse& response,
const remote::RemoteFunctionRequest& request) = 0;
};

/// Default implementation that wraps the actual thrift client.
class ThriftRemoteFunctionClient : public IRemoteFunctionClient {
public:
explicit ThriftRemoteFunctionClient(
std::unique_ptr<RemoteFunctionClient> client)
: client_(std::move(client)) {}

void invokeFunction(
remote::RemoteFunctionResponse& response,
const remote::RemoteFunctionRequest& request) override {
client_->sync_invokeFunction(response, request);
}

private:
std::unique_ptr<RemoteFunctionClient> client_;
};

/// Factory function type for creating remote function clients.
/// Parameters: location (socket address), eventBase (for async operations)
/// Returns: A unique_ptr to an IRemoteFunctionClient implementation.
using RemoteFunctionClientFactory = std::function<std::unique_ptr<
IRemoteFunctionClient>(folly::SocketAddress, folly::EventBase*)>;

std::unique_ptr<RemoteFunctionClient> getThriftClient(
folly::SocketAddress location,
folly::EventBase* eventBase);

/// Default factory that creates ThriftRemoteFunctionClient instances.
inline std::unique_ptr<IRemoteFunctionClient> getDefaultRemoteFunctionClient(
folly::SocketAddress location,
folly::EventBase* eventBase) {
return std::make_unique<ThriftRemoteFunctionClient>(
getThriftClient(location, eventBase));
}

} // namespace facebook::velox::functions
130 changes: 130 additions & 0 deletions velox/functions/remote/client/tests/RemoteFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/remote/client/Remote.h"
#include "velox/functions/remote/client/ThriftClient.h"
#include "velox/functions/remote/if/GetSerde.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h"
#include "velox/functions/remote/server/RemoteFunctionService.h"
#include "velox/functions/remote/utils/RemoteFunctionServiceProvider.h"
#include "velox/serializers/PrestoSerializer.h"
#include "velox/vector/VectorStream.h"

using ::apache::thrift::ThriftServer;
using ::facebook::velox::test::assertEqualVectors;
Expand Down Expand Up @@ -263,6 +266,133 @@ TEST_P(RemoteFunctionTest, connectionError) {
}
}

/// Mock implementation of IRemoteFunctionClient for testing without a real
/// thrift server.
class MockRemoteFunctionClient : public IRemoteFunctionClient {
public:
MOCK_METHOD(
void,
invokeFunction,
(remote::RemoteFunctionResponse&, const remote::RemoteFunctionRequest&),
(override));
};

/// Test fixture that uses mock clients instead of a real thrift server running
/// in the same process.
class MockRemoteFunctionTest : public functions::test::FunctionBaseTest {
public:
void SetUp() override {
mockClient_ =
std::make_shared<testing::NiceMock<MockRemoteFunctionClient>>();
}

void TearDown() override {
mockClient_.reset();
}

/// Registers a remote function with a mock client factory.
void registerMockRemoteFunction(
const std::string& name,
std::vector<exec::FunctionSignaturePtr> signatures) {
RemoteThriftVectorFunctionMetadata metadata;
metadata.serdeFormat = remote::PageFormat::PRESTO_PAGE;
// Location doesn't matter since we're using a mock client.
metadata.location = folly::SocketAddress("127.0.0.1", 12345);

// Capture mockClient_ by value (shared_ptr) so it stays alive.
auto mockClientPtr = mockClient_;
metadata.clientFactory = [mockClientPtr](
folly::SocketAddress, folly::EventBase*) {
// We need this level of indirection because clientFactory returns a
// unique_ptr so we wouldn't have a reference to the mock client.
class MockClientWrapper : public IRemoteFunctionClient {
public:
explicit MockClientWrapper(
std::shared_ptr<MockRemoteFunctionClient> mock)
: mock_(std::move(mock)) {}

void invokeFunction(
remote::RemoteFunctionResponse& response,
const remote::RemoteFunctionRequest& request) override {
mock_->invokeFunction(response, request);
}

private:
std::shared_ptr<MockRemoteFunctionClient> mock_;
};
return std::make_unique<MockClientWrapper>(mockClientPtr);
};

registerRemoteFunction(name, std::move(signatures), metadata);
}

/// Helper to set up a mock response with a serialized result vector.
void setMockResponse(
remote::RemoteFunctionResponse& response,
const RowVectorPtr& resultVector) {
auto serde = getSerde(remote::PageFormat::PRESTO_PAGE);
auto result = response.result();
result->rowCount() = resultVector->size();
result->pageFormat() = remote::PageFormat::PRESTO_PAGE;
result->payload() = rowVectorToIOBuf(resultVector, *pool(), serde.get());
}

protected:
std::shared_ptr<testing::NiceMock<MockRemoteFunctionClient>> mockClient_;
};

TEST_F(MockRemoteFunctionTest, mockClientIsCalled) {
// Register a mock remote function.
auto signatures = {exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("bigint")
.argumentType("bigint")
.build()};
registerMockRemoteFunction("mock_plus", signatures);

// Set up the mock to return a valid response.
EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
.WillOnce([this](
remote::RemoteFunctionResponse& response,
const remote::RemoteFunctionRequest& /*request*/) {
// Return doubled values: input {1,2,3} + {1,2,3} = {2,4,6}
auto resultVector = makeRowVector({makeFlatVector<int64_t>({2, 4, 6})});
setMockResponse(response, resultVector);
});

auto inputVector = makeFlatVector<int64_t>({1, 2, 3});
auto results = evaluate<SimpleVector<int64_t>>(
"mock_plus(c0, c0)", makeRowVector({inputVector}));

// Verify the mock returned our expected values.
auto expected = makeFlatVector<int64_t>({2, 4, 6});
assertEqualVectors(expected, results);
}

TEST_F(MockRemoteFunctionTest, mockClientThrowsException) {
// Register a mock remote function.
auto signatures = {exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("bigint")
.build()};
registerMockRemoteFunction("mock_throwing", signatures);

// Set up the mock to throw an exception.
EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
.WillOnce([](remote::RemoteFunctionResponse&,
const remote::RemoteFunctionRequest&) {
throw std::runtime_error("Mock connection error");
});

auto inputVector = makeFlatVector<int64_t>({1, 2, 3});

// Verify the exception is propagated.
VELOX_ASSERT_THROW(
evaluate<SimpleVector<int64_t>>(
"mock_throwing(c0)", makeRowVector({inputVector})),
"Mock connection error");
}

VELOX_INSTANTIATE_TEST_SUITE_P(
RemoteFunctionTestFixture,
RemoteFunctionTest,
Expand Down
Loading