@@ -16,7 +16,6 @@ using exec_aten::Tensor;
1616using exec_aten::TensorImpl;
1717using ::executorch::runtime::Error;
1818using ::executorch::runtime::KernelRuntimeContext;
19- using ::executorch::runtime::Span;
2019
2120namespace executorch {
2221namespace extension {
@@ -39,25 +38,13 @@ void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
3938 options_ = std::move (options);
4039}
4140
42- Span<const char *> SGDParamGroup::param_names () {
43- return param_names_;
44- }
45-
46- const Span<const char *> SGDParamGroup::param_names () const {
47- return param_names_;
48- }
49-
50- Span<Tensor> SGDParamGroup::param_data () {
51- return param_data_;
52- }
53-
54- const Span<Tensor> SGDParamGroup::param_data () const {
55- return param_data_;
41+ const std::map<exec_aten::string_view, exec_aten::Tensor>&
42+ SGDParamGroup::named_parameters () const {
43+ return named_parameters_;
5644}
5745
5846void SGD::add_param_group (const SGDParamGroup& param_group) {
59- SGDParamGroup param_group_ (
60- param_group.param_names (), param_group.param_data ());
47+ SGDParamGroup param_group_ (param_group.named_parameters ());
6148 if (!param_group.has_options ()) {
6249 param_group_.set_options (defaults_->clone ());
6350 } else {
@@ -66,13 +53,8 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
6653 param_groups_.emplace_back (std::move (param_group_));
6754}
6855
69- Error SGD::step (Span<const char *> gradient_names, Span<Tensor> gradient_data) {
70- // check that the number of gradient names matches the number of gradients
71- ET_CHECK_OR_RETURN_ERROR (
72- gradient_names.size () == gradient_data.size (),
73- InvalidState,
74- " Gradient names and gradients must have the same length." );
75-
56+ Error SGD::step (const std::map<exec_aten::string_view, exec_aten::Tensor>&
57+ named_gradients) {
7658 KernelRuntimeContext context;
7759 for (auto & group : param_groups_) {
7860 auto & options = static_cast <SGDOptions&>(group.options ());
@@ -81,85 +63,82 @@ Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
8163 auto dampening = options.dampening ();
8264 auto nesterov = options.nesterov ();
8365
84- for (int i = 0 ; i < group.param_names ().size (); i++) {
85- for (int j = 0 ; j < gradient_names.size (); j++) {
86- // if param name and gradient name match, run the optimizer step
87- if (strcmp (group.param_names ()[i], gradient_names[j]) == 0 ) {
88- auto d_p = gradient_data[j];
89- auto p = group.param_data ()[i];
90- if (weight_decay != 0 ) {
91- // uses weight_decay specified and adds it to the gradient
92- torch::executor::aten::add_outf (context, d_p, p, weight_decay, d_p);
93- if (context.failure_state () != Error::Ok) {
94- return context.failure_state ();
95- }
66+ for (auto param_iter = group.named_parameters ().begin ();
67+ param_iter != group.named_parameters ().end ();
68+ ++param_iter) {
69+ // if param name and gradient name match, run the optimizer step
70+ const auto & named_gradient = named_gradients.find (param_iter->first );
71+ if (named_gradient != named_gradients.end ()) {
72+ auto d_p = named_gradient->second ;
73+ auto p = param_iter->second ;
74+ if (weight_decay != 0 ) {
75+ // uses weight_decay specified and adds it to the gradient
76+ torch::executor::aten::add_outf (context, d_p, p, weight_decay, d_p);
77+ if (context.failure_state () != Error::Ok) {
78+ return context.failure_state ();
9679 }
97- if (momentum != 0 ) {
98- Tensor buf (nullptr );
99- auto param_state = state_.find (p.unsafeGetTensorImpl ());
100- // look for the momentum buffer for the given parameter. this is the
101- // momentum as of the previous epoch
102- if (param_state == state_.end ()) {
103- // create a new momentum buffer if it doesn't exist. this memory
104- // needs to be freed when the optimizer is destroyed
105- void * buf_ptr = malloc (d_p.nbytes ());
80+ }
81+ if (momentum != 0 ) {
82+ Tensor buf (nullptr );
83+ auto param_state = state_.find (p.unsafeGetTensorImpl ());
84+ // look for the momentum buffer for the given parameter. this is the
85+ // momentum as of the previous epoch
86+ if (param_state == state_.end ()) {
87+ // create a new momentum buffer if it doesn't exist. this memory
88+ // needs to be freed when the optimizer is destroyed
89+ void * buf_ptr = malloc (d_p.nbytes ());
10690
10791#ifdef USE_ATEN_LIB
108- std::vector<int64_t > sizes (
109- d_p.sizes ().begin (), d_p.sizes ().end ());
110- buf = torch::from_blob (buf_ptr, sizes, d_p.scalar_type ());
92+ std::vector<int64_t > sizes (d_p.sizes ().begin (), d_p.sizes ().end ());
93+ buf = torch::from_blob (buf_ptr, sizes, d_p.scalar_type ());
11194#else
112- TensorImpl* buf_impl = new TensorImpl (
113- d_p.scalar_type (),
114- d_p.sizes ().size (),
115- const_cast <TensorImpl::SizesType*>(d_p.sizes ().data ()),
116- buf_ptr,
117- const_cast <TensorImpl::DimOrderType*>(
118- d_p.dim_order ().data ()));
119- buf = Tensor (buf_impl);
95+ TensorImpl* buf_impl = new TensorImpl (
96+ d_p.scalar_type (),
97+ d_p.sizes ().size (),
98+ const_cast <TensorImpl::SizesType*>(d_p.sizes ().data ()),
99+ buf_ptr,
100+ const_cast <TensorImpl::DimOrderType*>(d_p.dim_order ().data ()));
101+ buf = Tensor (buf_impl);
120102#endif
121- torch::executor::aten::clone_outf (
122- context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
123- if (context.failure_state () != Error::Ok) {
124- return context.failure_state ();
125- }
126-
127- // save the state of the momentum buffer to be reused in later
128- // epochs
129- auto state = std::make_unique<SGDParamState>(buf);
130- state_[p.unsafeGetTensorImpl ()] = std::move (state);
131- } else {
132- buf = static_cast <SGDParamState&>(*param_state->second )
133- .momentum_buffer ();
134-
135- // update the momentum buffer and apply dampening
136- torch::executor::aten::mul_outf (context, buf, momentum, buf);
137- if (context.failure_state () != Error::Ok) {
138- return context.failure_state ();
139- }
140- torch::executor::aten::add_outf (
141- context, buf, d_p, 1 - dampening, buf);
142- if (context.failure_state () != Error::Ok) {
143- return context.failure_state ();
144- }
103+ torch::executor::aten::clone_outf (
104+ context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
105+ if (context.failure_state () != Error::Ok) {
106+ return context.failure_state ();
145107 }
146- if (nesterov) {
147- // apply nesterov momentum
148- torch::executor::aten::add_outf (context, d_p, buf, momentum, d_p);
149- if (context.failure_state () != Error::Ok) {
150- return context.failure_state ();
151- }
152- } else {
153- d_p = buf;
108+
109+ // save the state of the momentum buffer to be reused in later
110+ // epochs
111+ auto state = std::make_unique<SGDParamState>(buf);
112+ state_[p.unsafeGetTensorImpl ()] = std::move (state);
113+ } else {
114+ buf = static_cast <SGDParamState&>(*param_state->second )
115+ .momentum_buffer ();
116+
117+ // update the momentum buffer and apply dampening
118+ torch::executor::aten::mul_outf (context, buf, momentum, buf);
119+ if (context.failure_state () != Error::Ok) {
120+ return context.failure_state ();
121+ }
122+ torch::executor::aten::add_outf (
123+ context, buf, d_p, 1 - dampening, buf);
124+ if (context.failure_state () != Error::Ok) {
125+ return context.failure_state ();
154126 }
155127 }
156- // update the parameter using the gradient and learning rate
157- torch::executor::aten::add_outf (
158- context, p, d_p, -1 * options.lr (), p);
159- if (context.failure_state () != Error::Ok) {
160- return context.failure_state ();
128+ if (nesterov) {
129+ // apply nesterov momentum
130+ torch::executor::aten::add_outf (context, d_p, buf, momentum, d_p);
131+ if (context.failure_state () != Error::Ok) {
132+ return context.failure_state ();
133+ }
134+ } else {
135+ d_p = buf;
161136 }
162- break ;
137+ }
138+ // update the parameter using the gradient and learning rate
139+ torch::executor::aten::add_outf (context, p, d_p, -1 * options.lr (), p);
140+ if (context.failure_state () != Error::Ok) {
141+ return context.failure_state ();
163142 }
164143 }
165144 }
0 commit comments