@@ -23,6 +23,7 @@ using framework::Scope;
2323using framework::TensorArray;
2424using framework::LoDTensor;
2525using framework::Variable;
26+ using framework::DySeqMetaBatch;
2627
2728namespace detail {
2829
@@ -33,6 +34,29 @@ inline void CreateVariables(Scope& scope,
3334 }
3435}
3536
37+ /*
38+ * The inputs with sequence should be reordered when they are split, so the
39+ * boot_states should be reordered in the same order.
40+ *
41+ * NOTE This may require that the `pre_state` of the first time step should just
42+ * copy the `boot_state` rather than reference it, for that the content should
43+ * be reordered, but the RNN op should not change the `boot_state` as an input
44+ * variable's content.
45+ */
46+ template <typename T>
47+ inline void ReorderBootState (const DySeqMetaBatch& metas,
48+ const LoDTensor& boot_state, LoDTensor* tensor,
49+ const platform::Place& dst_place) {
50+ for (size_t seq_id = 0 ; seq_id < metas.size (); seq_id++) {
51+ auto slice = tensor->Slice <T>(seq_id, seq_id + 1 );
52+ auto boot_slice =
53+ boot_state.Slice <T>(metas[seq_id].ori_idx , metas[seq_id].ori_idx + 1 );
54+ // TODO(superjom) pass in device context as an argument
55+ slice.template CopyFrom <T>(boot_slice, dst_place,
56+ platform::CPUDeviceContext ());
57+ }
58+ }
59+
3660} // namespace detail
3761
3862class DynamicRecurrentOpProtoAndCheckerMaker
@@ -69,26 +93,26 @@ void DynamicRecurrentOp::Run(const Scope& scope,
6993 CreateScopes ();
7094 WriteStepInputs ();
7195 InitStates ();
96+ WriteStepOutputs ();
7297
7398 // call stepnet in all the time steps
7499 for (size_t step = 0 ; step < cache_.num_steps ; step++) {
75100 auto & step_scope = cache_.GetScope (step);
76101 stepnet_->Run (step_scope, dev_ctx);
77102 }
78103
79- WriteStepOutputs ();
80104 ConcatOutputs ();
81105}
82106
83107void DynamicRecurrentOp::SplitInputs () const {
84108 // TODO(superjom) make level a config
85109 // TODO(superjom) check all the inputs has the same LoD
86110 int level = 0 ;
87- const auto & inlinks = cache_.inlinks ;
88- for (const auto & item : inlinks) {
111+ for (const auto & item : cache_.inlinks ) {
89112 const auto & var = item.second ;
90113 const auto & tensor = var->Get <LoDTensor>();
91114 TensorArray& ta = step_inputs_[item.first ];
115+
92116 dy_seq_metas_[item.first ] =
93117 ta.Unpack (tensor, level, true /* length_descend*/ );
94118
@@ -120,17 +144,11 @@ void DynamicRecurrentOp::WriteStepInputs() const {
120144}
121145
122146void DynamicRecurrentOp::WriteStepOutputs () const {
123- for (size_t step = 0 ; step < cache_.scopes ->size (); step++) {
124- auto & scope = cache_.GetScope (step);
125- for (auto & item : step_outputs_) {
126- auto * var = scope.FindVar (item.first );
127- if (var == nullptr ) {
128- var = scope.NewVar (item.first );
129- }
130- auto * tensor = var->GetMutable <LoDTensor>();
131- item.second .WriteShared (step, *tensor);
132- }
147+ // initialize step outputs
148+ for (const auto & item : cache_.outlinks ) {
149+ step_outputs_.emplace (item.first , TensorArray ());
133150 }
151+ PADDLE_ENFORCE_GT (step_outputs_.size (), 0UL );
134152}
135153
136154void DynamicRecurrentOp::CreateScopes () const {
@@ -145,73 +163,107 @@ void DynamicRecurrentOp::CreateScopes() const {
145163 PADDLE_ENFORCE_NOT_NULL (stepnet_, " stepnet should be set first" );
146164 std::vector<std::string> memories;
147165 std::vector<std::string> pre_memories;
166+ std::vector<std::string> stepnet_outputs;
148167 std::transform (arg_.memories .begin (), arg_.memories .end (),
149168 std::back_inserter (memories),
150169 [](const rnn::MemoryAttr& m) { return m.var ; });
151170 std::transform (arg_.memories .begin (), arg_.memories .end (),
152171 std::back_inserter (pre_memories),
153172 [](const rnn::MemoryAttr& m) { return m.pre_var ; });
173+ for (const auto & item : stepnet_->Outputs ()) {
174+ for (const auto & var : item.second ) {
175+ stepnet_outputs.push_back (var);
176+ }
177+ }
154178
155179 for (size_t step = 0 ; step < cache_.num_steps ; step++) {
156180 auto & scope = cache_.GetScope (step);
157181 detail::CreateVariables (scope, arg_.inlinks );
158182 detail::CreateVariables (scope, arg_.outlinks );
159183 detail::CreateVariables (scope, memories);
160184 detail::CreateVariables (scope, pre_memories);
185+ detail::CreateVariables (scope, stepnet_outputs);
161186 }
162187}
163188
164189void DynamicRecurrentOp::ConcatOutputs () const {
165190 // TODO(superjom) transform this to a config
166191 int level = 0 ;
167- // TODO(superjom) pass in some lod
168- // just a placeholder
169- framework::LoD lod;
192+ for (size_t step = 0 ; step < cache_.num_steps ; step++) {
193+ auto & scope = cache_.GetScope (step);
194+ for (auto & item : step_outputs_) {
195+ auto * var = scope.FindVar (item.first );
196+ PADDLE_ENFORCE_NOT_NULL (var);
197+ auto * tensor = var->GetMutable <LoDTensor>();
198+ tensor->mutable_data <value_type>(platform::CPUPlace ());
199+ item.second .WriteShared (step, *tensor);
200+ }
201+ }
202+ // the inlinks' lods should be the same, so randomly get one lod.
203+ const auto & some_lod =
204+ cache_.scope ->FindVar (arg_.inlinks .front ())->Get <LoDTensor>().lod ();
205+ const auto & some_meta = dy_seq_metas_[arg_.inlinks .front ()];
170206 for (auto & item : step_outputs_) {
171- auto tensor = item.second .Pack (level, dy_seq_metas_[item. first ], lod );
172- auto & output = cache_.outlinks [item.first ]->Get <LoDTensor>();
173- const_cast <LoDTensor*>(& output)->ShareDataWith <value_type>(tensor);
207+ auto tensor = item.second .Pack (level, some_meta, some_lod );
208+ auto * output = cache_.outlinks [item.first ]->GetMutable <LoDTensor>();
209+ const_cast <LoDTensor*>(output)->ShareDataWith <value_type>(tensor);
174210 }
175211}
176212
177213void DynamicRecurrentOp::InitStates () const {
178- // init the first state
179- // TODO(superjom) parepare the scenerio that boot state not exists
180- for (auto memory : arg_.memories ) {
181- auto * boot_state_var = cache_.scope ->FindVar (memory.boot_var );
182- PADDLE_ENFORCE_NOT_NULL (boot_state_var);
183- auto & boot_state = boot_state_var->Get <LoDTensor>();
184- const auto & dims = boot_state.dims ();
185-
186- for (size_t step = 0 ; step < cache_.num_steps ; step++) {
187- auto & cur_scope = cache_.GetScope (step);
188- // link pre-state to boot_state
189- // init state and pre-state
190- auto * pre_state = cur_scope.FindVar (memory.pre_var );
191- PADDLE_ENFORCE_NOT_NULL (pre_state);
192- pre_state->GetMutable <LoDTensor>();
193-
194- auto * state = cur_scope.FindVar (memory.var );
195- PADDLE_ENFORCE_NOT_NULL (state);
196- state->GetMutable <LoDTensor>()->Resize (dims);
197- state->GetMutable <LoDTensor>()->mutable_data <value_type>(
198- platform::CPUPlace ());
199-
200- if (step == 0 ) {
201- auto * pre_state_tensor = pre_state->GetMutable <LoDTensor>();
202- pre_state_tensor->Resize (boot_state.dims ());
203- pre_state_tensor->ShareDataWith <value_type>(boot_state);
204- } else {
205- auto & pre_scope = cache_.GetScope (step - 1 );
206- auto * state_pre = pre_scope.FindVar (memory.var );
207- PADDLE_ENFORCE_NOT_NULL (state_pre);
208- pre_state->GetMutable <LoDTensor>()->ShareDataWith <value_type>(
209- *state_pre->GetMutable <LoDTensor>());
210- }
214+ for (size_t step = 0 ; step < cache_.num_steps ; step++) {
215+ for (const auto & memory : arg_.memories ) {
216+ CreateState (memory, step);
217+ LinkState (memory, step);
211218 }
212219 }
213220}
214221
222+ void DynamicRecurrentOp::CreateState (const rnn::MemoryAttr& memory,
223+ size_t step) const {
224+ auto & scope = cache_.GetScope (step);
225+ auto & state = *cache_.GetTensor (scope, memory.var );
226+ auto & boot_state = *cache_.GetTensor (*cache_.scope , memory.boot_var );
227+
228+ size_t num_instances =
229+ step_inputs_[arg_.inlinks .front ()].Read (step).dims ()[0 ];
230+ auto dims = boot_state.dims ();
231+ dims[0 ] = num_instances;
232+
233+ state.Resize (dims);
234+ state.mutable_data <value_type>(platform::CPUPlace ());
235+ states_[memory.var ].WriteShared (step, state);
236+ }
237+
238+ void DynamicRecurrentOp::LinkState (const rnn::MemoryAttr& memory,
239+ size_t step) const {
240+ auto & scope = cache_.GetScope (step);
241+ auto & state_pre = *cache_.GetTensor (scope, memory.pre_var );
242+
243+ // all the step_inputs' metas should be the same, just randomly select one
244+ // and get the dyseq meta.
245+ const auto & some_meta = dy_seq_metas_[arg_.inlinks .front ()];
246+ size_t num_instances =
247+ step_inputs_[arg_.inlinks .front ()].Read (step).dims ()[0 ];
248+
249+ LoDTensor* pre_state{nullptr };
250+ if (step == 0 ) {
251+ pre_state = cache_.GetTensor (*cache_.scope , memory.boot_var );
252+ pre_state->mutable_data <float >(platform::CPUPlace ());
253+ // allocate memory
254+ state_pre.Resize (pre_state->dims ());
255+ state_pre.mutable_data <value_type>(platform::CPUPlace ());
256+ detail::ReorderBootState<value_type>(some_meta, *pre_state, &state_pre,
257+ pre_state->place ());
258+ } else {
259+ pre_state = cache_.GetTensor (cache_.GetScope (step - 1 ), memory.var );
260+ }
261+
262+ // shink and share from previous state
263+ auto shrinked_pre_state = pre_state->Slice <value_type>(0 , num_instances);
264+ state_pre.ShareDataWith <value_type>(shrinked_pre_state);
265+ }
266+
215267void DynamicRecurrentOp::ArgCache::Init (
216268 const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
217269 const paddle::framework::Scope& scope, rnn::Argument* arg) {
@@ -261,6 +313,12 @@ Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
261313 return var;
262314}
263315
316+ LoDTensor* DynamicRecurrentOp::ArgCache::GetTensor (
317+ const framework::Scope& scope, const std::string& name) {
318+ auto * var = GetVariable (scope, name);
319+ return var->GetMutable <LoDTensor>();
320+ }
321+
264322const rnn::ArgumentName DynamicRecurrentOp::kArgName {
265323 " step_net" , " step_scopes" , " inlinks" , " outlinks" ,
266324 " memories" , " pre_memories" , " boot_memories" };
0 commit comments