We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 7a0e772 + e4984f1 commit ee9ee56Copy full SHA for ee9ee56
paddle/operators/sgd_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
14
15
#pragma once
16
#include "glog/logging.h"
17
+#include "paddle/framework/eigen.h"
18
#include "paddle/framework/operator.h"
19
20
namespace paddle {
@@ -30,8 +31,10 @@ class SGDOpKernel : public framework::OpKernel {
30
31
32
param_out->mutable_data<T>(ctx.GetPlace());
33
- param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) =
34
- param.flat<T>() - lr * grad.flat<T>();
+ framework::EigenVector<T>::Flatten(*param_out)
35
+ .device(*(ctx.GetEigenDevice<Place>())) =
36
+ framework::EigenVector<T>::Flatten(param) -
37
+ lr * framework::EigenVector<T>::Flatten(grad);
38
}
39
};
40
0 commit comments