Skip to content

Commit 1c1f73b

Browse files
authored
Feature/dynamic recurrent op forward test (#4729)
1 parent 6316b40 commit 1c1f73b

File tree

9 files changed

+323
-69
lines changed

9 files changed

+323
-69
lines changed

paddle/framework/tensor_array.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
7676
const std::vector<DySeqMeta>& meta, const LoD& lod,
7777
size_t level);
7878

79+
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch& meta, int batch_id) {
80+
// collect indice need to copy to the batch
81+
std::vector<size_t> indice;
82+
for (const auto& seq : meta) {
83+
size_t id = seq.begin + batch_id;
84+
if (id >= seq.end) break;
85+
indice.push_back(id);
86+
}
87+
return indice;
88+
}
89+
7990
} // namespace detail
8091

8192
const LoDTensor& TensorArray::Read(size_t index) const {
@@ -113,8 +124,8 @@ LoDTensor TensorArray::Pack(size_t level, const std::vector<DySeqMeta>& meta,
113124
return detail::PackDynamicBatch(values_, meta, lod, level);
114125
}
115126

116-
std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
117-
bool length_desend) {
127+
DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
128+
bool length_desend) {
118129
detail::DynamicBatchUnpacker unpacker(source, level,
119130
length_desend /*descend*/);
120131

@@ -129,6 +140,7 @@ std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
129140
Write(batch_id, unpacker.GetBatch(batch_id));
130141
}
131142

143+
PADDLE_ENFORCE(!unpacker.meta.empty());
132144
return unpacker.meta;
133145
}
134146

@@ -218,13 +230,7 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
218230
PADDLE_ENFORCE(!meta.empty(), "should build meta first");
219231
LoDTensor result;
220232

221-
// collect indice need to copy to the batch
222-
std::vector<size_t> indice;
223-
for (const auto& seq : meta) {
224-
size_t id = seq.begin + index;
225-
if (id >= seq.end) break;
226-
indice.push_back(id);
227-
}
233+
auto indice = detail::GenDyBatchIndice(meta, index);
228234
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
229235

230236
// copy the indice of records in LoDTensor
@@ -237,9 +243,9 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
237243
for (size_t i = 0; i < indice.size(); i++) {
238244
auto index = indice[i];
239245
auto target = result.Slice<value_type>(i, i + 1);
240-
auto source_ = source->Slice<value_type>(index, index + 1);
246+
auto slice = source->Slice<value_type>(index, index + 1);
241247

242-
target.CopyFrom<value_type>(source_, platform::CPUPlace(),
248+
target.CopyFrom<value_type>(slice, platform::CPUPlace(),
243249
platform::CPUDeviceContext());
244250
}
245251

paddle/framework/tensor_array.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ struct DySeqMeta {
3434
size_t ori_idx;
3535
};
3636

37+
using DySeqMetaBatch = std::vector<DySeqMeta>;
38+
39+
/*
40+
* Extract the indices of instances.
41+
*/
42+
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch &metas, int batch_id);
43+
3744
/*
3845
* TensorArray is a C-array-like array of tensors, it is meant to be used with
3946
* dynamic iteration primitives such as while_loop. It is used to segment inputs
@@ -69,16 +76,15 @@ class TensorArray {
6976
* Recover the original LoD-arranged LoDTensor with the `values`, `level` and
7077
* `indice_map`.
7178
*/
72-
LoDTensor Pack(size_t level, const std::vector<DySeqMeta> &meta,
79+
LoDTensor Pack(size_t level, const DySeqMetaBatch &meta,
7380
const LoD &lod) const;
7481

7582
/*
7683
* Split LoDTensor in some `level` and write the generated batches to
7784
* `values`; if `length_desend` is set, sort by length in descending order, else in
7885
* ascending order.
7986
*/
80-
std::vector<DySeqMeta> Unpack(const LoDTensor &source, int level,
81-
bool length_desend);
87+
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
8288

8389
/*
8490
* Pack the values into a tensor with rank one higher than each tensor in

paddle/operators/dynamic_recurrent_op.cc

Lines changed: 110 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using framework::Scope;
2323
using framework::TensorArray;
2424
using framework::LoDTensor;
2525
using framework::Variable;
26+
using framework::DySeqMetaBatch;
2627

2728
namespace detail {
2829

@@ -33,6 +34,29 @@ inline void CreateVariables(Scope& scope,
3334
}
3435
}
3536

37+
/*
38+
* The inputs with sequence should be reordered when they are split, so the
39+
* boot_states should be reordered in the same order.
40+
*
41+
* NOTE This may require that the `pre_state` of the first time step should just
42+
* copy the `boot_state` rather than reference it, for that the content should
43+
* be reordered, but the RNN op should not change the `boot_state` as an input
44+
* variable's content.
45+
*/
46+
template <typename T>
47+
inline void ReorderBootState(const DySeqMetaBatch& metas,
48+
const LoDTensor& boot_state, LoDTensor* tensor,
49+
const platform::Place& dst_place) {
50+
for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
51+
auto slice = tensor->Slice<T>(seq_id, seq_id + 1);
52+
auto boot_slice =
53+
boot_state.Slice<T>(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
54+
// TODO(superjom) pass in device context as an argument
55+
slice.template CopyFrom<T>(boot_slice, dst_place,
56+
platform::CPUDeviceContext());
57+
}
58+
}
59+
3660
} // namespace detail
3761

3862
class DynamicRecurrentOpProtoAndCheckerMaker
@@ -69,26 +93,26 @@ void DynamicRecurrentOp::Run(const Scope& scope,
6993
CreateScopes();
7094
WriteStepInputs();
7195
InitStates();
96+
WriteStepOutputs();
7297

7398
// call stepnet in all the time steps
7499
for (size_t step = 0; step < cache_.num_steps; step++) {
75100
auto& step_scope = cache_.GetScope(step);
76101
stepnet_->Run(step_scope, dev_ctx);
77102
}
78103

79-
WriteStepOutputs();
80104
ConcatOutputs();
81105
}
82106

83107
void DynamicRecurrentOp::SplitInputs() const {
84108
// TODO(superjom) make level a config
85109
// TODO(superjom) check all the inputs have the same LoD
86110
int level = 0;
87-
const auto& inlinks = cache_.inlinks;
88-
for (const auto& item : inlinks) {
111+
for (const auto& item : cache_.inlinks) {
89112
const auto& var = item.second;
90113
const auto& tensor = var->Get<LoDTensor>();
91114
TensorArray& ta = step_inputs_[item.first];
115+
92116
dy_seq_metas_[item.first] =
93117
ta.Unpack(tensor, level, true /*length_descend*/);
94118

@@ -120,17 +144,11 @@ void DynamicRecurrentOp::WriteStepInputs() const {
120144
}
121145

122146
void DynamicRecurrentOp::WriteStepOutputs() const {
123-
for (size_t step = 0; step < cache_.scopes->size(); step++) {
124-
auto& scope = cache_.GetScope(step);
125-
for (auto& item : step_outputs_) {
126-
auto* var = scope.FindVar(item.first);
127-
if (var == nullptr) {
128-
var = scope.NewVar(item.first);
129-
}
130-
auto* tensor = var->GetMutable<LoDTensor>();
131-
item.second.WriteShared(step, *tensor);
132-
}
147+
// initialize step outputs
148+
for (const auto& item : cache_.outlinks) {
149+
step_outputs_.emplace(item.first, TensorArray());
133150
}
151+
PADDLE_ENFORCE_GT(step_outputs_.size(), 0UL);
134152
}
135153

136154
void DynamicRecurrentOp::CreateScopes() const {
@@ -145,73 +163,107 @@ void DynamicRecurrentOp::CreateScopes() const {
145163
PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
146164
std::vector<std::string> memories;
147165
std::vector<std::string> pre_memories;
166+
std::vector<std::string> stepnet_outputs;
148167
std::transform(arg_.memories.begin(), arg_.memories.end(),
149168
std::back_inserter(memories),
150169
[](const rnn::MemoryAttr& m) { return m.var; });
151170
std::transform(arg_.memories.begin(), arg_.memories.end(),
152171
std::back_inserter(pre_memories),
153172
[](const rnn::MemoryAttr& m) { return m.pre_var; });
173+
for (const auto& item : stepnet_->Outputs()) {
174+
for (const auto& var : item.second) {
175+
stepnet_outputs.push_back(var);
176+
}
177+
}
154178

155179
for (size_t step = 0; step < cache_.num_steps; step++) {
156180
auto& scope = cache_.GetScope(step);
157181
detail::CreateVariables(scope, arg_.inlinks);
158182
detail::CreateVariables(scope, arg_.outlinks);
159183
detail::CreateVariables(scope, memories);
160184
detail::CreateVariables(scope, pre_memories);
185+
detail::CreateVariables(scope, stepnet_outputs);
161186
}
162187
}
163188

164189
void DynamicRecurrentOp::ConcatOutputs() const {
165190
// TODO(superjom) transform this to a config
166191
int level = 0;
167-
// TODO(superjom) pass in some lod
168-
// just a placeholder
169-
framework::LoD lod;
192+
for (size_t step = 0; step < cache_.num_steps; step++) {
193+
auto& scope = cache_.GetScope(step);
194+
for (auto& item : step_outputs_) {
195+
auto* var = scope.FindVar(item.first);
196+
PADDLE_ENFORCE_NOT_NULL(var);
197+
auto* tensor = var->GetMutable<LoDTensor>();
198+
tensor->mutable_data<value_type>(platform::CPUPlace());
199+
item.second.WriteShared(step, *tensor);
200+
}
201+
}
202+
// the inlinks' lods should be the same, so randomly get one lod.
203+
const auto& some_lod =
204+
cache_.scope->FindVar(arg_.inlinks.front())->Get<LoDTensor>().lod();
205+
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
170206
for (auto& item : step_outputs_) {
171-
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
172-
auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
173-
const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
207+
auto tensor = item.second.Pack(level, some_meta, some_lod);
208+
auto* output = cache_.outlinks[item.first]->GetMutable<LoDTensor>();
209+
const_cast<LoDTensor*>(output)->ShareDataWith<value_type>(tensor);
174210
}
175211
}
176212

177213
void DynamicRecurrentOp::InitStates() const {
178-
// init the first state
179-
// TODO(superjom) parepare the scenerio that boot state not exists
180-
for (auto memory : arg_.memories) {
181-
auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
182-
PADDLE_ENFORCE_NOT_NULL(boot_state_var);
183-
auto& boot_state = boot_state_var->Get<LoDTensor>();
184-
const auto& dims = boot_state.dims();
185-
186-
for (size_t step = 0; step < cache_.num_steps; step++) {
187-
auto& cur_scope = cache_.GetScope(step);
188-
// link pre-state to boot_state
189-
// init state and pre-state
190-
auto* pre_state = cur_scope.FindVar(memory.pre_var);
191-
PADDLE_ENFORCE_NOT_NULL(pre_state);
192-
pre_state->GetMutable<LoDTensor>();
193-
194-
auto* state = cur_scope.FindVar(memory.var);
195-
PADDLE_ENFORCE_NOT_NULL(state);
196-
state->GetMutable<LoDTensor>()->Resize(dims);
197-
state->GetMutable<LoDTensor>()->mutable_data<value_type>(
198-
platform::CPUPlace());
199-
200-
if (step == 0) {
201-
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
202-
pre_state_tensor->Resize(boot_state.dims());
203-
pre_state_tensor->ShareDataWith<value_type>(boot_state);
204-
} else {
205-
auto& pre_scope = cache_.GetScope(step - 1);
206-
auto* state_pre = pre_scope.FindVar(memory.var);
207-
PADDLE_ENFORCE_NOT_NULL(state_pre);
208-
pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
209-
*state_pre->GetMutable<LoDTensor>());
210-
}
214+
for (size_t step = 0; step < cache_.num_steps; step++) {
215+
for (const auto& memory : arg_.memories) {
216+
CreateState(memory, step);
217+
LinkState(memory, step);
211218
}
212219
}
213220
}
214221

222+
/*
 * Prepare the state tensor `memory.var` in the scope of time step `step`.
 *
 * The tensor gets the dims of the boot state (`memory.boot_var`, looked up in
 * the op's own scope) with dim 0 replaced by dim 0 of the first inlink's step
 * input at `step`. It is then allocated on the CPU and written (shared) into
 * `states_[memory.var]` at index `step`.
 */
void DynamicRecurrentOp::CreateState(const rnn::MemoryAttr& memory,
                                     size_t step) const {
  auto& step_scope = cache_.GetScope(step);
  auto& state = *cache_.GetTensor(step_scope, memory.var);
  auto& boot_state = *cache_.GetTensor(*cache_.scope, memory.boot_var);

  // dim 0 of this step's input is used as the number of rows of the state
  const auto& step_input = step_inputs_[arg_.inlinks.front()].Read(step);
  size_t num_instances = step_input.dims()[0];

  auto state_dims = boot_state.dims();
  state_dims[0] = num_instances;
  state.Resize(state_dims);
  state.mutable_data<value_type>(platform::CPUPlace());

  states_[memory.var].WriteShared(step, state);
}
237+
238+
/*
 * Fill the pre-state tensor `memory.pre_var` in the scope of time step `step`.
 *
 *  - step == 0: the pre-state becomes a reordered copy of the boot state
 *    (`memory.boot_var` in the op's scope); row `i` is taken from the boot
 *    state row `metas[i].ori_idx`. The boot state itself is only read.
 *  - step  > 0: the pre-state shares the first `num_instances` rows of the
 *    previous step's state (`memory.var` in the scope of `step - 1`), where
 *    `num_instances` is dim 0 of the first inlink's step input at `step`.
 *
 * The share must not run for step 0: it would rebind the pre-state to the
 * unreordered boot state, throwing away the reordered copy made just before
 * and aliasing the op's input variable.
 */
void DynamicRecurrentOp::LinkState(const rnn::MemoryAttr& memory,
                                   size_t step) const {
  auto& scope = cache_.GetScope(step);
  auto& state_pre = *cache_.GetTensor(scope, memory.pre_var);

  // all the step_inputs' metas should be the same, so the first inlink's
  // dyseq meta is used.
  const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
  size_t num_instances =
      step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];

  if (step == 0) {
    LoDTensor* boot_state = cache_.GetTensor(*cache_.scope, memory.boot_var);
    boot_state->mutable_data<value_type>(platform::CPUPlace());
    // allocate memory, then copy the boot state rows in the reordered order
    state_pre.Resize(boot_state->dims());
    state_pre.mutable_data<value_type>(platform::CPUPlace());
    detail::ReorderBootState<value_type>(some_meta, *boot_state, &state_pre,
                                         boot_state->place());
  } else {
    LoDTensor* pre_state =
        cache_.GetTensor(cache_.GetScope(step - 1), memory.var);
    // shrink and share from the previous state
    auto shrinked_pre_state = pre_state->Slice<value_type>(0, num_instances);
    state_pre.ShareDataWith<value_type>(shrinked_pre_state);
  }
}
266+
215267
void DynamicRecurrentOp::ArgCache::Init(
216268
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
217269
const paddle::framework::Scope& scope, rnn::Argument* arg) {
@@ -261,6 +313,12 @@ Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
261313
return var;
262314
}
263315

316+
/*
 * Look up the variable `name` via GetVariable(scope, name) and return its
 * LoDTensor, obtained with GetMutable<LoDTensor>().
 */
LoDTensor* DynamicRecurrentOp::ArgCache::GetTensor(
    const framework::Scope& scope, const std::string& name) {
  return GetVariable(scope, name)->GetMutable<LoDTensor>();
}
321+
264322
const rnn::ArgumentName DynamicRecurrentOp::kArgName{
265323
"step_net", "step_scopes", "inlinks", "outlinks",
266324
"memories", "pre_memories", "boot_memories"};

0 commit comments

Comments
 (0)