@@ -17,6 +17,36 @@ limitations under the License. */
1717namespace paddle {
1818namespace operators {
1919
20+ struct BeamSearchDecodeFunctor {
21+ BeamSearchDecodeFunctor (const LoDTensorArray& step_ids,
22+ const LoDTensorArray& step_scores,
23+ LoDTensor* id_tensor, LoDTensor* score_tensor)
24+ : step_ids_(step_ids),
25+ step_scores_ (step_scores),
26+ id_tensor_(id_tensor),
27+ score_tensor_(score_tensor) {}
28+
29+ template <typename T>
30+ void operator ()() const ;
31+
32+ const LoDTensorArray& step_ids_;
33+ const LoDTensorArray& step_scores_;
34+ LoDTensor* id_tensor_;
35+ LoDTensor* score_tensor_;
36+ };
37+
38+ template <typename T>
39+ void BeamSearchDecodeFunctor::operator ()() const {
40+ BeamSearchDecoder<T> beam_search_decoder;
41+ beam_search_decoder.PackAllSteps (step_ids_, step_scores_, id_tensor_,
42+ score_tensor_);
43+ }
44+
45+ template <>
46+ void BeamSearchDecodeFunctor::operator ()<bool>() const {
47+ PADDLE_THROW (" beam search decode op does not support bool!" );
48+ }
49+
2050class BeamSearchDecodeOp : public framework ::OperatorBase {
2151 public:
2252 BeamSearchDecodeOp (const std::string& type,
@@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
4575 LoDTensor* sentenceIds = ctx.Output <LoDTensor>(" SentenceIds" );
4676 LoDTensor* sentenceScores = ctx.Output <LoDTensor>(" SentenceScores" );
4777
48- BeamSearchDecoder< float > beam_search_decoder;
49- beam_search_decoder. PackAllSteps (*ids, *scores, sentenceIds ,
50- sentenceScores);
78+ framework::VisitDataType (
79+ framework::ToDataType (scores-> at ( 0 ). type ()) ,
80+ BeamSearchDecodeFunctor (*ids, *scores, sentenceIds, sentenceScores) );
5181 }
5282};
5383
0 commit comments