Skip to content

Commit 1775307

Browse files
Guilherme Kunigamifacebook-github-bot
authored andcommitted
test: Add dependency injection for remote function thrift client (#16231)
Summary: Refactored the remote function client to use dependency injection, enabling proper mocking in tests: 1. Added `IRemoteFunctionClient` abstract interface in ThriftClient.h with `invokeFunction()` method 2. Added `ThriftRemoteFunctionClient` default implementation that wraps the actual thrift client 3. Added `RemoteFunctionClientFactory` type alias for creating clients 4. Added optional `clientFactory` field to `RemoteThriftVectorFunctionMetadata` 5. Modified `RemoteThriftFunction` to use injected factory (with fallback to default) 6. Added `MockRemoteFunctionTest` test fixture demonstrating mock client usage 7. Added 3 new tests using mock clients instead of real thrift server This allows tests to inject mock clients instead of requiring a real thrift server, making tests faster and more isolated. Differential Revision: D92213049
1 parent a539ae3 commit 1775307

File tree

4 files changed

+190
-3
lines changed

4 files changed

+190
-3
lines changed

velox/functions/remote/client/Remote.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class RemoteThriftFunction : public RemoteVectorFunction {
3333
const RemoteThriftVectorFunctionMetadata& metadata)
3434
: RemoteVectorFunction(functionName, inputArgs, metadata),
3535
location_(metadata.location),
36-
thriftClient_(getThriftClient(location_, &eventBase_)) {}
36+
client_(createClient(metadata)) {}
3737

3838
std::unique_ptr<remote::RemoteFunctionResponse> invokeRemoteFunction(
3939
const remote::RemoteFunctionRequest& request) const override {
4040
auto remoteResponse = std::make_unique<remote::RemoteFunctionResponse>();
41-
thriftClient_->sync_invokeFunction(*remoteResponse, request);
41+
client_->invokeFunction(*remoteResponse, request);
4242
return remoteResponse;
4343
}
4444

@@ -47,10 +47,18 @@ class RemoteThriftFunction : public RemoteVectorFunction {
4747
}
4848

4949
private:
50+
std::unique_ptr<IRemoteFunctionClient> createClient(
51+
const RemoteThriftVectorFunctionMetadata& metadata) {
52+
if (metadata.clientFactory) {
53+
return metadata.clientFactory(metadata.location, &eventBase_);
54+
}
55+
return getDefaultRemoteFunctionClient(metadata.location, &eventBase_);
56+
}
57+
5058
folly::SocketAddress location_;
5159
folly::EventBase eventBase_;
5260

53-
std::unique_ptr<RemoteFunctionClient> thriftClient_;
61+
std::unique_ptr<IRemoteFunctionClient> client_;
5462
};
5563

5664
std::shared_ptr<exec::VectorFunction> createRemoteFunction(

velox/functions/remote/client/Remote.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <folly/SocketAddress.h>
2020
#include "velox/functions/remote/client/RemoteVectorFunction.h"
21+
#include "velox/functions/remote/client/ThriftClient.h"
2122

2223
namespace facebook::velox::functions {
2324

@@ -27,6 +28,11 @@ struct RemoteThriftVectorFunctionMetadata
2728
/// Note that this can hold a network location (ip/port pair) or a unix domain
2829
/// socket path (see SocketAddress::makeFromPath()).
2930
folly::SocketAddress location;
31+
32+
/// Optional factory for creating remote function clients. If not set, the
33+
/// default thrift client factory is used. This enables dependency injection
34+
/// for testing with mock clients.
35+
RemoteFunctionClientFactory clientFactory;
3036
};
3137

3238
/// Registers a new remote function. It will use the meatadata defined in

velox/functions/remote/client/ThriftClient.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,51 @@ namespace facebook::velox::functions {
2424
using RemoteFunctionClient =
2525
apache::thrift::Client<remote::RemoteFunctionService>;
2626

27+
/// Abstract interface for the remote function client, enabling dependency
28+
/// injection and mocking in tests.
29+
class IRemoteFunctionClient {
30+
public:
31+
virtual ~IRemoteFunctionClient() = default;
32+
33+
/// Invokes the remote function synchronously.
34+
virtual void invokeFunction(
35+
remote::RemoteFunctionResponse& response,
36+
const remote::RemoteFunctionRequest& request) = 0;
37+
};
38+
39+
/// Default implementation that wraps the actual thrift client.
40+
class ThriftRemoteFunctionClient : public IRemoteFunctionClient {
41+
public:
42+
explicit ThriftRemoteFunctionClient(
43+
std::unique_ptr<RemoteFunctionClient> client)
44+
: client_(std::move(client)) {}
45+
46+
void invokeFunction(
47+
remote::RemoteFunctionResponse& response,
48+
const remote::RemoteFunctionRequest& request) override {
49+
client_->sync_invokeFunction(response, request);
50+
}
51+
52+
private:
53+
std::unique_ptr<RemoteFunctionClient> client_;
54+
};
55+
56+
/// Factory function type for creating remote function clients.
57+
/// Parameters: location (socket address), eventBase (for async operations)
58+
/// Returns: A unique_ptr to an IRemoteFunctionClient implementation.
59+
using RemoteFunctionClientFactory = std::function<std::unique_ptr<
60+
IRemoteFunctionClient>(folly::SocketAddress, folly::EventBase*)>;
61+
2762
std::unique_ptr<RemoteFunctionClient> getThriftClient(
2863
folly::SocketAddress location,
2964
folly::EventBase* eventBase);
3065

66+
/// Default factory that creates ThriftRemoteFunctionClient instances.
67+
inline std::unique_ptr<IRemoteFunctionClient> getDefaultRemoteFunctionClient(
68+
folly::SocketAddress location,
69+
folly::EventBase* eventBase) {
70+
return std::make_unique<ThriftRemoteFunctionClient>(
71+
getThriftClient(location, eventBase));
72+
}
73+
3174
} // namespace facebook::velox::functions

velox/functions/remote/client/tests/RemoteFunctionTest.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929
#include "velox/functions/prestosql/StringFunctions.h"
3030
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
3131
#include "velox/functions/remote/client/Remote.h"
32+
#include "velox/functions/remote/client/ThriftClient.h"
33+
#include "velox/functions/remote/if/GetSerde.h"
3234
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h"
3335
#include "velox/functions/remote/server/RemoteFunctionService.h"
3436
#include "velox/functions/remote/utils/RemoteFunctionServiceProvider.h"
3537
#include "velox/serializers/PrestoSerializer.h"
38+
#include "velox/vector/VectorStream.h"
3639

3740
using ::apache::thrift::ThriftServer;
3841
using ::facebook::velox::test::assertEqualVectors;
@@ -263,6 +266,133 @@ TEST_P(RemoteFunctionTest, connectionError) {
263266
}
264267
}
265268

269+
/// Mock implementation of IRemoteFunctionClient for testing without a real
270+
/// thrift server.
271+
class MockRemoteFunctionClient : public IRemoteFunctionClient {
272+
public:
273+
MOCK_METHOD(
274+
void,
275+
invokeFunction,
276+
(remote::RemoteFunctionResponse&, const remote::RemoteFunctionRequest&),
277+
(override));
278+
};
279+
280+
/// Test fixture that uses mock clients instead of a real thrift server running
281+
/// in the same process.
282+
class MockRemoteFunctionTest : public functions::test::FunctionBaseTest {
283+
public:
284+
void SetUp() override {
285+
mockClient_ =
286+
std::make_shared<testing::NiceMock<MockRemoteFunctionClient>>();
287+
}
288+
289+
void TearDown() override {
290+
mockClient_.reset();
291+
}
292+
293+
/// Registers a remote function with a mock client factory.
294+
void registerMockRemoteFunction(
295+
const std::string& name,
296+
std::vector<exec::FunctionSignaturePtr> signatures) {
297+
RemoteThriftVectorFunctionMetadata metadata;
298+
metadata.serdeFormat = remote::PageFormat::PRESTO_PAGE;
299+
// Location doesn't matter since we're using a mock client.
300+
metadata.location = folly::SocketAddress("127.0.0.1", 12345);
301+
302+
// Capture mockClient_ by value (shared_ptr) so it stays alive.
303+
auto mockClientPtr = mockClient_;
304+
metadata.clientFactory = [mockClientPtr](
305+
folly::SocketAddress, folly::EventBase*) {
306+
// We need this level of indirection because clientFactory returns a
307+
// unique_ptr so we wouldn't have a reference to the mock client.
308+
class MockClientWrapper : public IRemoteFunctionClient {
309+
public:
310+
explicit MockClientWrapper(
311+
std::shared_ptr<MockRemoteFunctionClient> mock)
312+
: mock_(std::move(mock)) {}
313+
314+
void invokeFunction(
315+
remote::RemoteFunctionResponse& response,
316+
const remote::RemoteFunctionRequest& request) override {
317+
mock_->invokeFunction(response, request);
318+
}
319+
320+
private:
321+
std::shared_ptr<MockRemoteFunctionClient> mock_;
322+
};
323+
return std::make_unique<MockClientWrapper>(mockClientPtr);
324+
};
325+
326+
registerRemoteFunction(name, std::move(signatures), metadata);
327+
}
328+
329+
/// Helper to set up a mock response with a serialized result vector.
330+
void setMockResponse(
331+
remote::RemoteFunctionResponse& response,
332+
const RowVectorPtr& resultVector) {
333+
auto serde = getSerde(remote::PageFormat::PRESTO_PAGE);
334+
auto result = response.result();
335+
result->rowCount() = resultVector->size();
336+
result->pageFormat() = remote::PageFormat::PRESTO_PAGE;
337+
result->payload() = rowVectorToIOBuf(resultVector, *pool(), serde.get());
338+
}
339+
340+
protected:
341+
std::shared_ptr<testing::NiceMock<MockRemoteFunctionClient>> mockClient_;
342+
};
343+
344+
TEST_F(MockRemoteFunctionTest, mockClientIsCalled) {
345+
// Register a mock remote function.
346+
auto signatures = {exec::FunctionSignatureBuilder()
347+
.returnType("bigint")
348+
.argumentType("bigint")
349+
.argumentType("bigint")
350+
.build()};
351+
registerMockRemoteFunction("mock_plus", signatures);
352+
353+
// Set up the mock to return a valid response.
354+
EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
355+
.WillOnce([this](
356+
remote::RemoteFunctionResponse& response,
357+
const remote::RemoteFunctionRequest& /*request*/) {
358+
// Return doubled values: input {1,2,3} + {1,2,3} = {2,4,6}
359+
auto resultVector = makeRowVector({makeFlatVector<int64_t>({2, 4, 6})});
360+
setMockResponse(response, resultVector);
361+
});
362+
363+
auto inputVector = makeFlatVector<int64_t>({1, 2, 3});
364+
auto results = evaluate<SimpleVector<int64_t>>(
365+
"mock_plus(c0, c0)", makeRowVector({inputVector}));
366+
367+
// Verify the mock returned our expected values.
368+
auto expected = makeFlatVector<int64_t>({2, 4, 6});
369+
assertEqualVectors(expected, results);
370+
}
371+
372+
TEST_F(MockRemoteFunctionTest, mockClientThrowsException) {
373+
// Register a mock remote function.
374+
auto signatures = {exec::FunctionSignatureBuilder()
375+
.returnType("bigint")
376+
.argumentType("bigint")
377+
.build()};
378+
registerMockRemoteFunction("mock_throwing", signatures);
379+
380+
// Set up the mock to throw an exception.
381+
EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
382+
.WillOnce([](remote::RemoteFunctionResponse&,
383+
const remote::RemoteFunctionRequest&) {
384+
throw std::runtime_error("Mock connection error");
385+
});
386+
387+
auto inputVector = makeFlatVector<int64_t>({1, 2, 3});
388+
389+
// Verify the exception is propagated.
390+
VELOX_ASSERT_THROW(
391+
evaluate<SimpleVector<int64_t>>(
392+
"mock_throwing(c0)", makeRowVector({inputVector})),
393+
"Mock connection error");
394+
}
395+
266396
VELOX_INSTANTIATE_TEST_SUITE_P(
267397
RemoteFunctionTestFixture,
268398
RemoteFunctionTest,

0 commit comments

Comments
 (0)