2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.cc
@@ -91,7 +91,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
   return known_vars;
 }
 
-void Scope::DeleteScope(Scope* scope) {
+void Scope::DeleteScope(Scope* scope) const {
   std::unique_lock<std::mutex> lock(mutex_);
   auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
   PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.h
@@ -63,7 +63,7 @@ class Scope {
   /// Find the scope or an ancestor scope that contains the given variable.
   const Scope* FindScope(const Variable* var) const;
 
-  void DeleteScope(Scope* scope);
+  void DeleteScope(Scope* scope) const;
 
   /// Drop all kids scopes belonged to this scope.
   void DropKids();
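Reviewer note: DeleteScope() still locks mutex_ and erases from kids_, so making it const presumably relies on those members being declared mutable (not visible in this diff). A minimal, self-contained sketch of that pattern, using hypothetical names rather than the real Scope:

#include <algorithm>
#include <list>
#include <mutex>

// Hypothetical stand-in for the real Scope, only to show why a const method compiles here.
class MiniScope {
 public:
  void DeleteScope(MiniScope* kid) const {
    std::unique_lock<std::mutex> lock(mutex_);  // allowed in a const method: mutex_ is mutable
    auto it = std::find(kids_.begin(), kids_.end(), kid);
    if (it != kids_.end()) {
      kids_.erase(it);  // allowed: kids_ is mutable as well
      delete kid;
    }
  }

 private:
  mutable std::mutex mutex_;
  mutable std::list<MiniScope*> kids_;
};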
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/grpc_server.cc
@@ -60,7 +60,7 @@ class RequestSend final : public RequestBase {
               framework::Scope* scope, ReceivedQueue* queue,
               const platform::DeviceContext* dev_ctx)
       : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
-    request_.reset(new VariableResponse(scope, dev_ctx_));
+    request_.reset(new VariableResponse(false, scope, dev_ctx_));
     int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                 cq_, cq_, this);
@@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase {
         executor_(executor),
         program_(program),
         prefetch_ctx_(prefetch_ctx) {
-    request_.reset(new VariableResponse(scope, dev_ctx_));
+    request_.reset(new VariableResponse(false, scope, dev_ctx_));
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                 cq_, cq_, this);
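Both gRPC handlers pass false for the new flag, so they keep resolving the destination variable in the scope they were given; the local-scope mode appears to be reserved for later callers. A hedged sketch of the two call forms, based on the constructor declared in variable_response.h further down in this diff (the function and variable names below are assumptions):

#include "paddle/fluid/operators/detail/variable_response.h"

namespace pd = paddle;

void ConstructResponses(const pd::framework::Scope* scope,
                        const pd::platform::DeviceContext* dev_ctx) {
  // false: deserialize into a variable resolved from `scope` (the behavior used here).
  pd::operators::detail::VariableResponse shared(false, scope, dev_ctx);
  // true: deserialize into a fresh variable inside a child scope owned by the response.
  pd::operators::detail::VariableResponse isolated(true, scope, dev_ctx);
}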
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
                                framework::Variable** var) {
-  operators::detail::VariableResponse resp(scope, &ctx);
+  operators::detail::VariableResponse resp(false, scope, &ctx);
   PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
   *var = resp.GetVar();
 }
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/serde_test.cc
@@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
   framework::Scope scope;
   scope.Var("myvar");
-  operators::detail::VariableResponse resp(&scope, &ctx);
+  operators::detail::VariableResponse resp(false, &scope, &ctx);
   EXPECT_EQ(resp.Parse(msg), 0);
 
   framework::Variable* var2 = resp.GetVar();
@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
   // deserialize zero-copy
   framework::Scope scope;
   scope.Var("myvar");
-  operators::detail::VariableResponse resp(&scope, &ctx);
+  operators::detail::VariableResponse resp(false, &scope, &ctx);
   if (from_type == 0) {
     EXPECT_EQ(resp.Parse(msg), 0);
   } else {
9 changes: 3 additions & 6 deletions paddle/fluid/operators/detail/variable_response.cc
@@ -114,8 +114,7 @@ bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
-  auto var = scope_->FindVar(meta_.varname());
-  auto* tensor = var->GetMutable<framework::LoDTensor>();
+  auto* tensor = InitVar()->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);
 
   framework::LoD lod;
@@ -151,8 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
-  auto var = scope_->FindVar(meta_.varname());
-  auto* slr = var->GetMutable<framework::SelectedRows>();
+  auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
   slr->set_height(meta_.slr_height());
   auto* tensor = slr->mutable_value();
   tensor->Resize(dims);
@@ -174,8 +172,7 @@ bool VariableResponse::CopySelectRowsData(
 bool VariableResponse::CopySelectRowsData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, int length) {
-  auto var = scope_->FindVar(meta_.varname());
-  auto* slr = var->GetMutable<framework::SelectedRows>();
+  auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
   slr->mutable_rows()->resize(length /
                               framework::SizeOfType(typeid(int64_t)));  // int64
   int64_t* rows_data = slr->mutable_rows()->data();
24 changes: 20 additions & 4 deletions paddle/fluid/operators/detail/variable_response.h
@@ -36,11 +36,13 @@ namespace detail {
 
 class VariableResponse {
  public:
-  VariableResponse(const framework::Scope* scope,
+  VariableResponse(bool use_local_scope, const framework::Scope* scope,
                    const platform::DeviceContext* dev_ctx)
-      : scope_(scope), dev_ctx_(dev_ctx) {}
+      : use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) {
+    local_scope_ = &scope->NewScope();
+  }
 
-  virtual ~VariableResponse() {}
+  virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); }
 
   // return:
   // 0:ok.
@@ -54,11 +56,23 @@ class VariableResponse {
   // other: number of error field.
   int Parse(const ::grpc::ByteBuffer& byte_buffer);
 
+  const framework::Scope& GetLocalScope() const { return *local_scope_; }
+
   inline std::string Varname() { return meta_.varname(); }
   inline std::string OutVarname() { return meta_.out_varname(); }
 
   // should call parse first.
-  framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
+  framework::Variable* GetVar() {
+    return local_scope_->FindVar(meta_.varname());
+  }
+
+  framework::Variable* InitVar() {
+    if (use_local_scope_) {
+      return local_scope_->Var(meta_.varname());
+    } else {
+      return scope_->FindVar(meta_.varname());
+    }
+  }
 
  private:
   bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
@@ -73,7 +87,9 @@ class VariableResponse {
                                 const framework::DDim& dims, int length);
 
  private:
+  bool use_local_scope_ = false;
   const framework::Scope* scope_;
+  framework::Scope* local_scope_ = nullptr;
   const platform::DeviceContext* dev_ctx_;
   // only Skeleton
   sendrecv::VariableMessage meta_;
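Putting the pieces together: the response now always owns a child scope, InitVar() decides where the incoming variable lives, and the destructor releases that child via scope_->DeleteScope(local_scope_), which appears to be why DeleteScope() becomes const in this PR (scope_ is a const Scope*). GetVar() searching local_scope_ still works when use_local_scope is false, since FindVar() falls back to ancestor scopes. A hedged usage sketch based only on the declarations above; the handler name and flow are hypothetical:

#include "paddle/fluid/operators/detail/variable_response.h"

namespace detail = paddle::operators::detail;

// Hypothetical handler: deserialize an incoming variable without touching the shared scope.
void HandleIncomingVar(const ::grpc::ByteBuffer& msg,
                       const paddle::framework::Scope* scope,
                       const paddle::platform::DeviceContext* dev_ctx) {
  detail::VariableResponse resp(true /*use_local_scope*/, scope, dev_ctx);
  if (resp.Parse(msg) != 0) {
    return;  // the child scope is still freed by ~VariableResponse()
  }
  paddle::framework::Variable* var = resp.GetVar();       // created in the child scope
  const paddle::framework::Scope& local = resp.GetLocalScope();
  // ... consume `var`, e.g. run something against `local` here ...
  // Leaving this function destroys `resp`, which calls scope->DeleteScope(&local).
}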
2 changes: 1 addition & 1 deletion paddle/fluid/operators/split_byref_op.h
@@ -33,7 +33,7 @@ class SplitByrefOpKernel : public framework::OpKernel<T> {
       // NOTE: no need to call mutable_data here to allocate memory.
       auto* out = outs[i];
       VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
-      *out = std::move(in->Slice(row_offset, row_offset + out->dims()[0]));
+      *out = in->Slice(row_offset, row_offset + out->dims()[0]);
       row_offset += out->dims()[0];
     }
   }
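On the split_byref_op.h change: Slice() returns its result by value, so the assigned expression is already a temporary (an rvalue) and wrapping it in std::move() adds nothing; dropping it is purely cosmetic and does not change behavior. A small generic illustration, with std::vector standing in for Tensor:

#include <utility>
#include <vector>

std::vector<int> MakeSlice() { return {1, 2, 3}; }  // returns by value, like Slice()

void AssignSlice(std::vector<int>* out) {
  // *out = std::move(MakeSlice());  // std::move on a temporary is redundant
  *out = MakeSlice();                // the temporary already move-assigns into *out
}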