diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp
index b09406a628ba..7e0b57af689c 100644
--- a/velox/functions/remote/client/Remote.cpp
+++ b/velox/functions/remote/client/Remote.cpp
@@ -16,12 +16,29 @@
 
 #include "velox/functions/remote/client/Remote.h"
 
+#include
+#include
+#include
+#include
+
+#include
 #include
+#include
 #include "velox/functions/remote/client/RemoteVectorFunction.h"
 #include "velox/functions/remote/client/ThriftClient.h"
 #include "velox/functions/remote/if/GetSerde.h"
 #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h"
 
+DEFINE_int32(
+    remote_function_retry_count,
+    3,
+    "Number of retries for remote function calls on transport errors");
+
+DEFINE_int32(
+    remote_function_retry_max_backoff_sec,
+    8,
+    "Maximum exponential backoff in seconds for remote function retries");
+
 namespace facebook::velox::functions {
 namespace {
 
@@ -32,14 +49,44 @@ class RemoteThriftFunction : public RemoteVectorFunction {
       const std::vector& inputArgs,
       const RemoteThriftVectorFunctionMetadata& metadata)
       : RemoteVectorFunction(functionName, inputArgs, metadata),
+        functionName_(functionName),
         location_(metadata.location),
-        thriftClient_(getThriftClient(location_, &eventBase_)) {}
+        client_(createClient(metadata)) {
+    VLOG(1) << "Created RemoteThriftFunction '" << functionName_ << "' for "
+            << location_.describe();
+  }
 
   std::unique_ptr invokeRemoteFunction(
       const remote::RemoteFunctionRequest& request) const override {
     auto remoteResponse = std::make_unique();
-    thriftClient_->sync_invokeFunction(*remoteResponse, request);
-    return remoteResponse;
+
+    int retryCount = 0;
+    int expIntervalSec = 1;
+
+    // Only TTransportException and AsyncSocketException are retried; any
+    // other exception leaves the loop at once and propagates to the caller.
+    while (true) {
+      try {
+        VLOG(2) << "Invoking remote function '" << functionName_
+                << "' (socket=" << location_.describe() << ")";
+
+        client_->invokeFunction(*remoteResponse, request);
+
+        VLOG(2) << "Remote function '" << functionName_ << "' call succeeded";
+        return remoteResponse;
+
+      } catch (const apache::thrift::transport::TTransportException& e) {
+        if (!handleRetryableError(e.what(), retryCount, expIntervalSec)) {
+          throw;
+        }
+      } catch (const folly::AsyncSocketException& e) {
+        std::string errorMsg = fmt::format(
+            "{} (type={})", e.what(), static_cast(e.getType()));
+        if (!handleRetryableError(errorMsg, retryCount, expIntervalSec)) {
+          throw;
+        }
+      }
+    }
   }
 
   std::string remoteLocationToString() const override {
@@ -47,10 +94,76 @@ class RemoteThriftFunction : public RemoteVectorFunction {
   }
 
  private:
-  folly::SocketAddress location_;
-  folly::EventBase eventBase_;
+  // Creates the initial client and records the factory that produced it so
+  // reconnectClient() can build a replacement. Runs from the constructor's
+  // initializer for client_, so it may only touch members declared before
+  // client_ (eventBase_ and clientFactory_).
+  std::unique_ptr createClient(
+      const RemoteThriftVectorFunctionMetadata& metadata) {
+    if (metadata.clientFactory) {
+      clientFactory_ = metadata.clientFactory;
+      return clientFactory_(metadata.location, &eventBase_);
+    }
+    clientFactory_ = getDefaultRemoteFunctionClient;
+    return clientFactory_(metadata.location, &eventBase_);
+  }
 
-  std::unique_ptr thriftClient_;
+  // Handles retryable errors with exponential backoff.
+  // Returns true if retry should continue, false if retries exhausted.
+  // When retrying, it recreates the client, sleeps, doubles expIntervalSec
+  // (capped by the max-backoff flag) and increments retryCount.
+  bool handleRetryableError(
+      const std::string& errorMsg,
+      int& retryCount,
+      int& expIntervalSec) const {
+    LOG(ERROR) << "Transport error in remote function '" << functionName_
+               << "': " << errorMsg << " (attempt=" << (retryCount + 1) << "/"
+               << (FLAGS_remote_function_retry_count + 1) << ")";
+
+    if (retryCount < FLAGS_remote_function_retry_count) {
+      reconnectClient();
+      sleepWithJitter(expIntervalSec);
+      expIntervalSec = std::min(
+          expIntervalSec * 2, FLAGS_remote_function_retry_max_backoff_sec);
+      ++retryCount;
+      return true;
+    }
+
+    LOG(ERROR) << "Remote function '" << functionName_ << "' call failed after "
+               << FLAGS_remote_function_retry_count << " retries";
+    return false;
+  }
+
+  // Replaces client_ with a fresh instance from the stored factory.
+  void reconnectClient() const {
+    LOG(WARNING) << "Reconnecting thrift client for '" << functionName_
+                 << "' to " << location_.describe();
+    client_ = clientFactory_(location_, &eventBase_);
+  }
+
+  // Blocks the calling thread for a randomized backoff delay.
+  void sleepWithJitter(int expIntervalSec) const {
+    static thread_local std::mt19937 rng(std::random_device{}());
+    // Draw from [0.5, expIntervalSec + 0.5) seconds and sleep for the
+    // fractional value: truncating to whole seconds would turn every draw
+    // below 1.0 into a zero-length sleep. The interval is clamped to at least
+    // one second so a non-positive max-backoff flag cannot invert the range.
+    std::uniform_real_distribution<double> dist(
+        0.5, std::max(expIntervalSec, 1) + 0.5);
+    const double sleepIntervalSec = dist(rng);
+
+    LOG(INFO) << "Sleeping for " << sleepIntervalSec
+              << " seconds before retry for '" << functionName_ << "'";
+    /* sleep override: intentional backoff for retry logic */
+    std::this_thread::sleep_for(
+        std::chrono::duration<double>(sleepIntervalSec));
+  }
+
+  const std::string functionName_;
+  folly::SocketAddress location_;
+  mutable folly::EventBase eventBase_;
+  mutable RemoteFunctionClientFactory clientFactory_;
+  mutable std::unique_ptr client_;
 };
 
 std::shared_ptr createRemoteFunction(
diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h
index 457d622bd92c..f7b0e49216ad 100644
--- a/velox/functions/remote/client/Remote.h
+++ b/velox/functions/remote/client/Remote.h
@@ -18,6 +18,7 @@
 
 #include
 #include
"velox/functions/remote/client/RemoteVectorFunction.h"
+#include "velox/functions/remote/client/ThriftClient.h"
 
 namespace facebook::velox::functions {
 
@@ -27,6 +28,11 @@ struct RemoteThriftVectorFunctionMetadata
   /// Note that this can hold a network location (ip/port pair) or a unix domain
   /// socket path (see SocketAddress::makeFromPath()).
   folly::SocketAddress location;
+
+  /// Optional factory for creating remote function clients. If not set, the
+  /// default thrift client factory is used. This enables dependency injection
+  /// for testing with mock clients.
+  RemoteFunctionClientFactory clientFactory;
 };
 
-/// Registers a new remote function. It will use the meatadata defined in
+/// Registers a new remote function. It will use the metadata defined in
diff --git a/velox/functions/remote/client/ThriftClient.h b/velox/functions/remote/client/ThriftClient.h
index d65db0158978..2009e0800b20 100644
--- a/velox/functions/remote/client/ThriftClient.h
+++ b/velox/functions/remote/client/ThriftClient.h
@@ -24,8 +24,51 @@ namespace facebook::velox::functions {
 using RemoteFunctionClient =
     apache::thrift::Client;
 
+/// Abstract interface for the remote function client, enabling dependency
+/// injection and mocking in tests.
+class IRemoteFunctionClient {
+ public:
+  virtual ~IRemoteFunctionClient() = default;
+
+  /// Invokes the remote function synchronously.
+  virtual void invokeFunction(
+      remote::RemoteFunctionResponse& response,
+      const remote::RemoteFunctionRequest& request) = 0;
+};
+
+/// Default implementation that wraps the actual thrift client.
+class ThriftRemoteFunctionClient : public IRemoteFunctionClient {
+ public:
+  explicit ThriftRemoteFunctionClient(
+      std::unique_ptr client)
+      : client_(std::move(client)) {}
+
+  void invokeFunction(
+      remote::RemoteFunctionResponse& response,
+      const remote::RemoteFunctionRequest& request) override {
+    client_->sync_invokeFunction(response, request);
+  }
+
+ private:
+  std::unique_ptr client_;
+};
+
+/// Factory function type for creating remote function clients.
+/// Parameters: location (socket address), eventBase (for async operations)
+/// Returns: A unique_ptr to an IRemoteFunctionClient implementation.
+using RemoteFunctionClientFactory = std::function
+    (folly::SocketAddress, folly::EventBase*)>;
+
 std::unique_ptr getThriftClient(
     folly::SocketAddress location,
     folly::EventBase* eventBase);
 
+/// Default factory that creates ThriftRemoteFunctionClient instances.
+inline std::unique_ptr getDefaultRemoteFunctionClient(
+    folly::SocketAddress location,
+    folly::EventBase* eventBase) {
+  return std::make_unique(
+      getThriftClient(location, eventBase));
+}
+
 } // namespace facebook::velox::functions
diff --git a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp
index 712358753895..9e30ba641f95 100644
--- a/velox/functions/remote/client/tests/RemoteFunctionTest.cpp
+++ b/velox/functions/remote/client/tests/RemoteFunctionTest.cpp
@@ -16,9 +16,11 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 
 #include "velox/common/base/Exceptions.h"
 #include "velox/common/base/tests/GTestUtils.h"
@@ -29,10 +31,16 @@
 #include "velox/functions/prestosql/StringFunctions.h"
 #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
 #include "velox/functions/remote/client/Remote.h"
+#include "velox/functions/remote/client/ThriftClient.h"
+#include "velox/functions/remote/if/GetSerde.h"
 #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h"
 #include "velox/functions/remote/server/RemoteFunctionService.h"
 #include "velox/functions/remote/utils/RemoteFunctionServiceProvider.h"
 #include "velox/serializers/PrestoSerializer.h"
+#include "velox/vector/VectorStream.h"
+
+DECLARE_int32(remote_function_retry_count);
+DECLARE_int32(remote_function_retry_max_backoff_sec);
 
 using ::apache::thrift::ThriftServer;
 using ::facebook::velox::test::assertEqualVectors;
@@ -144,6 +152,9 @@ class RemoteFunctionTest
     OpaqueType::registerSerialization(
         "Foo",
Foo::serialize, Foo::deserialize);
   }
+
+ private:
+  gflags::FlagSaver flagSaver_;
 };
 
 TEST_P(RemoteFunctionTest, simple) {
@@ -247,6 +258,9 @@ TEST_P(RemoteFunctionTest, opaque) {
 }
 
 TEST_P(RemoteFunctionTest, connectionError) {
+  // Disable retries so the expected failure surfaces on the first attempt.
+  FLAGS_remote_function_retry_count = 0;
+
   auto inputVector = makeFlatVector({1, 2, 3, 4, 5});
   auto func = [&]() {
     evaluate>(
@@ -263,6 +277,217 @@ TEST_P(RemoteFunctionTest, connectionError) {
   }
 }
 
+/// Mock implementation of IRemoteFunctionClient for testing without a real
+/// thrift server.
+class MockRemoteFunctionClient : public IRemoteFunctionClient {
+ public:
+  MOCK_METHOD(
+      void,
+      invokeFunction,
+      (remote::RemoteFunctionResponse&, const remote::RemoteFunctionRequest&),
+      (override));
+};
+
+/// Test fixture that uses mock clients instead of a real thrift server running
+/// in the same process.
+class MockRemoteFunctionTest : public functions::test::FunctionBaseTest {
+ public:
+  void SetUp() override {
+    mockClient_ =
+        std::make_shared>();
+  }
+
+  void TearDown() override {
+    mockClient_.reset();
+  }
+
+  /// Registers a remote function whose client factory returns mock wrappers.
+  void registerMockRemoteFunction(
+      const std::string& name,
+      std::vector signatures) {
+    RemoteThriftVectorFunctionMetadata metadata;
+    metadata.serdeFormat = remote::PageFormat::PRESTO_PAGE;
+    // Location doesn't matter since we're using a mock client.
+    metadata.location = folly::SocketAddress("127.0.0.1", 12345);
+
+    // Copy the shared_ptr into the factory so the mock outlives TearDown().
+    auto mockClientPtr = mockClient_;
+    metadata.clientFactory = [mockClientPtr](
+                                 folly::SocketAddress, folly::EventBase*) {
+      // The factory has to hand out a unique_ptr while the test keeps its own
+      // handle on the mock, so each client is a wrapper forwarding to it.
+      class MockClientWrapper : public IRemoteFunctionClient {
+       public:
+        explicit MockClientWrapper(
+            std::shared_ptr mock)
+            : mock_(std::move(mock)) {}
+
+        void invokeFunction(
+            remote::RemoteFunctionResponse& response,
+            const remote::RemoteFunctionRequest& request) override {
+          mock_->invokeFunction(response, request);
+        }
+
+       private:
+        std::shared_ptr mock_;
+      };
+      return std::make_unique(mockClientPtr);
+    };
+
+    registerRemoteFunction(name, std::move(signatures), metadata);
+  }
+
+  /// Fills `response` with `resultVector` serialized in PRESTO_PAGE format.
+  void setMockResponse(
+      remote::RemoteFunctionResponse& response,
+      const RowVectorPtr& resultVector) {
+    auto serde = getSerde(remote::PageFormat::PRESTO_PAGE);
+    auto result = response.result();
+    result->rowCount() = resultVector->size();
+    result->pageFormat() = remote::PageFormat::PRESTO_PAGE;
+    result->payload() = rowVectorToIOBuf(resultVector, *pool(), serde.get());
+  }
+
+ protected:
+  // Saves and restores gflags values between tests (e.g., retry settings).
+  gflags::FlagSaver flagSaver_;
+  std::shared_ptr> mockClient_;
+};
+
+TEST_F(MockRemoteFunctionTest, mockClientIsCalled) {
+  // Register a mock remote function.
+  auto signatures = {exec::FunctionSignatureBuilder()
+                         .returnType("bigint")
+                         .argumentType("bigint")
+                         .argumentType("bigint")
+                         .build()};
+  registerMockRemoteFunction("mock_plus", signatures);
+
+  // Set up the mock to return a valid response.
+  EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
+      .WillOnce([this](
+                    remote::RemoteFunctionResponse& response,
+                    const remote::RemoteFunctionRequest& /*request*/) {
+        // Canned result for mock_plus(c0, c0): {1,2,3} + {1,2,3} = {2,4,6}.
+        auto resultVector = makeRowVector({makeFlatVector({2, 4, 6})});
+        setMockResponse(response, resultVector);
+      });
+
+  auto inputVector = makeFlatVector({1, 2, 3});
+  auto results = evaluate>(
+      "mock_plus(c0, c0)", makeRowVector({inputVector}));
+
+  // Verify the mock returned our expected values.
+  auto expected = makeFlatVector({2, 4, 6});
+  assertEqualVectors(expected, results);
+}
+
+TEST_F(MockRemoteFunctionTest, mockClientThrowsException) {
+  // Register a mock remote function.
+  auto signatures = {exec::FunctionSignatureBuilder()
+                         .returnType("bigint")
+                         .argumentType("bigint")
+                         .build()};
+  registerMockRemoteFunction("mock_throwing", signatures);
+
+  // Set up the mock to throw a non-transport exception.
+  EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
+      .WillOnce([](remote::RemoteFunctionResponse&,
+                   const remote::RemoteFunctionRequest&) {
+        throw std::runtime_error("Mock connection error");
+      });
+
+  auto inputVector = makeFlatVector({1, 2, 3});
+
+  // Verify the exception is propagated.
+  VELOX_ASSERT_THROW(
+      evaluate>(
+          "mock_throwing(c0)", makeRowVector({inputVector})),
+      "Mock connection error");
+}
+
+TEST_F(MockRemoteFunctionTest, retryOnTransportError) {
+  // Register a mock remote function.
+  auto signatures = {exec::FunctionSignatureBuilder()
+                         .returnType("bigint")
+                         .argumentType("bigint")
+                         .argumentType("bigint")
+                         .build()};
+  registerMockRemoteFunction("mock_retry", signatures);
+
+  // Keep the backoff cap low: every retry blocks in a jittered sleep.
+  FLAGS_remote_function_retry_count = 3;
+  FLAGS_remote_function_retry_max_backoff_sec = 1;
+
+  int callCount = 0;
+
+  // Set up the mock to fail once with a transport error, then succeed.
+  EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
+      .WillOnce([&callCount](
+                    remote::RemoteFunctionResponse&,
+                    const remote::RemoteFunctionRequest&) {
+        ++callCount;
+        throw apache::thrift::transport::TTransportException(
+            apache::thrift::transport::TTransportException::NOT_OPEN,
+            "Connection refused");
+      })
+      .WillOnce([this, &callCount](
+                    remote::RemoteFunctionResponse& response,
+                    const remote::RemoteFunctionRequest&) {
+        ++callCount;
+        auto resultVector = makeRowVector({makeFlatVector({2, 4, 6})});
+        setMockResponse(response, resultVector);
+      });
+
+  auto inputVector = makeFlatVector({1, 2, 3});
+  auto results = evaluate>(
+      "mock_retry(c0, c0)", makeRowVector({inputVector}));
+
+  // Verify the function succeeded after retry.
+  auto expected = makeFlatVector({2, 4, 6});
+  assertEqualVectors(expected, results);
+
+  // Verify the mock was called twice (one failure + one success).
+  EXPECT_EQ(callCount, 2);
+}
+
+TEST_F(MockRemoteFunctionTest, retryExhausted) {
+  // Register a mock remote function.
+  auto signatures = {exec::FunctionSignatureBuilder()
+                         .returnType("bigint")
+                         .argumentType("bigint")
+                         .build()};
+  registerMockRemoteFunction("mock_retry_exhausted", signatures);
+
+  // Keep the backoff cap low: every retry blocks in a jittered sleep.
+  FLAGS_remote_function_retry_count = 2;
+  FLAGS_remote_function_retry_max_backoff_sec = 1;
+
+  int callCount = 0;
+
+  // Set up the mock to always fail with transport errors.
+  EXPECT_CALL(*mockClient_, invokeFunction(testing::_, testing::_))
+      .WillRepeatedly([&callCount](
+                          remote::RemoteFunctionResponse&,
+                          const remote::RemoteFunctionRequest&) {
+        ++callCount;
+        throw apache::thrift::transport::TTransportException(
+            apache::thrift::transport::TTransportException::NOT_OPEN,
+            "Connection refused");
+      });
+
+  auto inputVector = makeFlatVector({1, 2, 3});
+
+  // Verify the exception is propagated after retries are exhausted.
+  VELOX_ASSERT_THROW(
+      evaluate>(
+          "mock_retry_exhausted(c0)", makeRowVector({inputVector})),
+      "Connection refused");
+
+  // Verify the mock was called retry_count + 1 times (initial + retries).
+  EXPECT_EQ(callCount, FLAGS_remote_function_retry_count + 1);
+}
+
 VELOX_INSTANTIATE_TEST_SUITE_P(
     RemoteFunctionTestFixture,
     RemoteFunctionTest,