4 changes: 4 additions & 0 deletions cmake/configure.cmake
@@ -118,6 +118,10 @@ endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")

if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()

if(WITH_GOLANG)
# we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go
9 changes: 7 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)

cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)

cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()


cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
11 changes: 11 additions & 0 deletions paddle/fluid/framework/executor.cc
@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"
#endif
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {

Executor::Executor(const platform::Place& place) : place_(place) {}

#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>()
->SendComplete();
}
#endif

void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
7 changes: 7 additions & 0 deletions paddle/fluid/framework/executor.h
@@ -44,6 +44,13 @@ class Executor {

explicit Executor(const platform::Place& place);

#ifdef PADDLE_WITH_DISTRIBUTE
/*
 * Send a signal to the pserver to mark that the current trainer has stopped.
 */
void Complete();
#endif

/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
*
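A usage sketch, not part of this diff: a trainer-side program would call the new Complete() once its local data is exhausted, so the pserver stops counting that trainer toward the batch barrier. RunTrainer and HasMoreBatches below are hypothetical placeholders; only Executor, Run, and Complete come from Paddle.

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/framework/executor.h"

// Hypothetical stand-in for a real data feed; not a Paddle API.
bool HasMoreBatches();

void RunTrainer(const paddle::framework::ProgramDesc& main_program,
                paddle::framework::Scope* scope) {
  paddle::framework::Executor exe(paddle::platform::CPUPlace());
  while (HasMoreBatches()) {
    // Same signature as the pybind binding below:
    // (program, scope, block_id, create_local_scope, create_vars).
    exe.Run(main_program, scope, 0, true, true);
  }
  // Tell the pserver this trainer is done so the remaining trainers are not
  // blocked waiting for its batch barrier message.
  exe.Complete();
}
#endif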
19 changes: 19 additions & 0 deletions paddle/fluid/operators/detail/grpc_client.cc
@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}

void GRPCClient::SendComplete() {
for (auto& it : channels_) {
this->AsyncSendComplete(it.first);
}
}

GRPCClient::~GRPCClient() {
Wait();
cq_.Shutdown();
@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++;
}

void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);

BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}

void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
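The client-side bookkeeping above follows a simple pattern: every async send increments req_count_, the Proceed() thread decrements it as responses arrive, and Wait() blocks until the count drops to zero. A minimal self-contained sketch of that pattern, with made-up names (ToyRPCClient), assuming only the standard library:

#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Simplified sketch of the GRPCClient counting pattern; not Paddle code.
class ToyRPCClient {
 public:
  // Like GRPCClient::SendComplete(): fan the message out to every endpoint.
  void SendComplete() {
    for (auto& ep : endpoints_) AsyncSend(ep);
  }

  // Like GRPCClient::Wait(): block until all in-flight requests finish.
  void Wait() {
    std::unique_lock<std::mutex> lk(sync_mutex_);
    sync_cond_.wait(lk, [this] { return req_count_ == 0; });
  }

 private:
  void AsyncSend(const std::string& ep) {
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      ++req_count_;
    }
    // Stand-in for the gRPC completion queue: a worker finishes the request
    // and signals the waiter, as Proceed() does.
    std::thread([this] {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      --req_count_;
      sync_cond_.notify_all();
    }).detach();
  }

  std::vector<std::string> endpoints_{"127.0.0.1:6174"};
  int req_count_ = 0;
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
};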
5 changes: 5 additions & 0 deletions paddle/fluid/operators/detail/grpc_client.h
@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {

void Wait() override;

void SendComplete() override;

protected:
void InitImpl() override;

@@ -204,6 +206,9 @@

void Proceed();

void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out);

std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);

private:
1 change: 1 addition & 0 deletions paddle/fluid/operators/detail/request_handler.h
@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"

class RPCServer;

3 changes: 3 additions & 0 deletions paddle/fluid/operators/detail/request_handler_impl.cc
@@ -49,6 +49,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->DecreaseClientNum();
Contributor review comment: Smart!
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
5 changes: 5 additions & 0 deletions paddle/fluid/operators/detail/rpc_client.h
@@ -53,6 +53,11 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;

// SendComplete tells all the servers that the current trainer has no more
// data to train on, so that the pserver can reduce its barrier count and
// continue training with the other trainers.
virtual void SendComplete() = 0;

virtual void Wait() = 0;

static constexpr int64_t rpc_time_out = 120 * 1000;
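For callers that do not go through the Executor, the new SendComplete() is reached via the RPCClient singleton, with the concrete transport picked at the call site. The sketch below simply mirrors Executor::Complete() from this diff; NotifyPserversTrainerDone is an illustrative name, not a real Paddle function.

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"

// Illustrative wrapper only; the body is the same call Executor::Complete()
// makes in executor.cc above.
void NotifyPserversTrainerDone() {
  ::paddle::operators::detail::RPCClient::GetInstance<
      ::paddle::operators::detail::GRPCClient>()
      ->SendComplete();
}
#endif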
22 changes: 13 additions & 9 deletions paddle/fluid/operators/detail/rpc_server.cc
@@ -43,7 +43,7 @@ void RPCServer::SavePort() const {

void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] {
barrier_cond_.wait(lock, [this, &rpc_name] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});

@@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
}

VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;

std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
if (b >= client_num_) {
lock.unlock();
barrier_cond_.notify_all();
lock.lock();
}
}

void RPCServer::DecreaseClientNum() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
}
barrier_cond_.notify_all();
}

void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
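The interaction between WaitBarrier() and the new DecreaseClientNum() is easiest to see in isolation: the barrier releases once the counter reaches client_num_, and a finished trainer lowers client_num_ so it no longer blocks the others. A minimal sketch with made-up names (ToyBarrier), assuming only the standard library:

#include <condition_variable>
#include <mutex>

// Simplified sketch of the RPCServer barrier logic; not Paddle code.
class ToyBarrier {
 public:
  explicit ToyBarrier(int client_num) : client_num_(client_num) {}

  // Like RPCServer::WaitBarrier(): block until every active client arrived.
  void WaitBarrier() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return counter_ >= client_num_; });
  }

  // Like RPCServer::IncreaseBatchBarrier(): one client checked in.
  void Increase() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      ++counter_;
    }
    cond_.notify_all();
  }

  // Like the new RPCServer::DecreaseClientNum(): a trainer that sent
  // COMPLETE_MESSAGE no longer takes part in the barrier.
  void DecreaseClientNum() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      --client_num_;
    }
    cond_.notify_all();
  }

 private:
  int client_num_;
  int counter_ = 0;
  std::mutex mutex_;
  std::condition_variable cond_;
};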
5 changes: 2 additions & 3 deletions paddle/fluid/operators/detail/rpc_server.h
@@ -60,7 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);

void DecreaseClientNum();
void ResetBarrierCounter();

protected:
@@ -79,8 +79,7 @@ class RPCServer {
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;

const int client_num_;
int client_num_;

std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle.

py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("complete", &Executor::Complete)
#endif
.def("run",
(void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
Executor::Run);