Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions paddle/fluid/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);

BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
Expand All @@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
}

return true;
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
}

bool RPCClient::Wait() {
Expand Down Expand Up @@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
PADDLE_ENFORCE(tag);

// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
Expand All @@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}

grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

Expand Down
36 changes: 26 additions & 10 deletions paddle/fluid/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ struct VarHandle {
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg);

class ClientBase {
class BaseProcessor {
public:
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL;
}

virtual ~ClientBase() {}
virtual ~BaseProcessor() {}

virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext());
Expand Down Expand Up @@ -91,9 +91,10 @@ class ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
RequestSendCallBack;

class SendProcessor : public ClientBase {
class SendProcessor : public BaseProcessor {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}

virtual ~SendProcessor() {}

Expand All @@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
RequestGetCallBack;

class GetProcessor : public ClientBase {
class GetProcessor : public BaseProcessor {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}

virtual ~GetProcessor() {}

Expand All @@ -126,17 +128,28 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};

class BatchBarrierProcessor : public ClientBase {
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}
: BaseProcessor(ch) {}

virtual ~BatchBarrierProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
};

class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}

virtual ~FetchBarrierProcessor() {}

virtual void Process() {}
sendrecv::VariableMessage reply_;
};

class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
Expand All @@ -151,7 +164,10 @@ class RPCClient {
const std::string& var_name,
int64_t time_out = 600 * 1000);

bool AsyncSendBatchBarrier(const std::string& ep,
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);

void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);

bool Wait();
Expand Down
21 changes: 15 additions & 6 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
Expand All @@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
}
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
queue_->Push('c');
MessageWithName msg_with_name =
// request name reply
std::make_pair(var_name, std::move(reply_));
queue_->Push(msg_with_name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe only push FETCH_BARRIER_MESSAGE message.

}

protected:
Expand All @@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
SimpleBlockQueue<MessageWithName>* queue_;
};

void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
int fetch_barriers = 0;
while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;
SimpleBlockQueue<MessageWithName> var_get_queue_;

// condition of the sub program
std::mutex barrier_mutex_;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace detail {

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"

typedef void (*DestroyCallback)(void*);

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
if (exit_flag) {
rpc_service_->ShutDown();
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break;
}
try {
Expand All @@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
}
rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size());
rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear();
} // while(true)
}
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
}
};
Expand Down
20 changes: 17 additions & 3 deletions python/paddle/fluid/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def transpile(self,
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program

def get_pserver_program(self, endpoint):
Expand Down Expand Up @@ -309,7 +311,8 @@ def get_pserver_program(self, endpoint):
for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op):
if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint)
self._append_pserver_ops(optimize_block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(optimize_block, op)
break
Expand Down Expand Up @@ -520,7 +523,8 @@ def _orig_varname(self, varname):
orig_var_name = varname[:suff_idx]
return orig_var_name

def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
Expand Down Expand Up @@ -576,7 +580,17 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
elif key == "LearningRate":
# leraning rate variable has already be created by non-optimize op,
# don't create it once again.
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
lr_varname = opt_op.input(key)[0]
if pserver_block.vars.has_key(lr_varname):
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
else:
origin_var = origin_program.global_block().vars[lr_varname]
tmpvar = pserver_block.create_var(
name=origin_var.name,
persistable=origin_var.persistable,
dtype=origin_var.dtype,
shape=origin_var.shape)
new_inputs[key] = tmpvar

for key in opt_op.input_names:
new_shape = None
Expand Down