Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
return known_vars;
}

void Scope::DeleteScope(Scope* scope) {
void Scope::DeleteScope(Scope* scope) const {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Scope {
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;

void DeleteScope(Scope* scope);
void DeleteScope(Scope* scope) const;

/// Drop all kids scopes belonged to this scope.
void DropKids();
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RequestSend final : public RequestBase {
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(scope, dev_ctx_));
request_.reset(new VariableResponse(false, scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
Expand Down Expand Up @@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase {
executor_(executor),
program_(program),
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
request_.reset(new VariableResponse(false, scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::detail::VariableResponse resp(scope, &ctx);
operators::detail::VariableResponse resp(false, scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope;
scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx);
operators::detail::VariableResponse resp(false, &scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0);

framework::Variable* var2 = resp.GetVar();
Expand Down Expand Up @@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy
framework::Scope scope;
scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx);
operators::detail::VariableResponse resp(false, &scope, &ctx);
if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0);
} else {
Expand Down
9 changes: 3 additions & 6 deletions paddle/fluid/operators/detail/variable_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname());
auto* tensor = var->GetMutable<framework::LoDTensor>();
auto* tensor = InitVar()->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);

framework::LoD lod;
Expand Down Expand Up @@ -151,8 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
Expand All @@ -174,8 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData(
bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();
Expand Down
26 changes: 22 additions & 4 deletions paddle/fluid/operators/detail/variable_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ namespace detail {

class VariableResponse {
public:
VariableResponse(const framework::Scope* scope,
VariableResponse(bool use_local_scope, const framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: scope_(scope), dev_ctx_(dev_ctx) {}
: use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) {
local_scope_ = &scope->NewScope();
}

virtual ~VariableResponse() {}
virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); }

// return:
// 0:ok.
Expand All @@ -54,11 +56,25 @@ class VariableResponse {
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);

const framework::Scope& GetLocalScope() const { return *local_scope_; }

inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }

// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
framework::Variable* GetVar() {
return local_scope_->FindVar(meta_.varname());
}

framework::Variable* InitVar() {
if (use_local_scope_) {
bool has_var = (scope_->FindVar(meta_.varname()) != nullptr);
PADDLE_ENFORCE(has_var);
return local_scope_->Var(meta_.varname());
} else {
return scope_->FindVar(meta_.varname());
}
}

private:
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
Expand All @@ -73,7 +89,9 @@ class VariableResponse {
const framework::DDim& dims, int length);

private:
bool use_local_scope_ = false;
const framework::Scope* scope_;
framework::Scope* local_scope_ = nullptr;
const platform::DeviceContext* dev_ctx_;
// only Skeleton
sendrecv::VariableMessage meta_;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/split_byref_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SplitByrefOpKernel : public framework::OpKernel<T> {
// NOTE: no need to call mutable_data here to allocate memory.
auto* out = outs[i];
VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
*out = std::move(in->Slice(row_offset, row_offset + out->dims()[0]));
*out = in->Slice(row_offset, row_offset + out->dims()[0]);
row_offset += out->dims()[0];
}
}
Expand Down