Skip to content

Commit ee9ee56

Browse files
authored
Merge pull request #2972 from jacquesqiao/fix-sgd-op
update tensor usage in sgd-op
2 parents 7a0e772 + e4984f1 commit ee9ee56

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

paddle/operators/sgd_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include "glog/logging.h"
17+
#include "paddle/framework/eigen.h"
1718
#include "paddle/framework/operator.h"
1819

1920
namespace paddle {
@@ -30,8 +31,10 @@ class SGDOpKernel : public framework::OpKernel {
3031

3132
param_out->mutable_data<T>(ctx.GetPlace());
3233

33-
param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) =
34-
param.flat<T>() - lr * grad.flat<T>();
34+
framework::EigenVector<T>::Flatten(*param_out)
35+
.device(*(ctx.GetEigenDevice<Place>())) =
36+
framework::EigenVector<T>::Flatten(param) -
37+
lr * framework::EigenVector<T>::Flatten(grad);
3538
}
3639
};
3740

0 commit comments

Comments
 (0)