
Commit 7c76e03

JacobSzwejbka authored and facebook-github-bot committed
Switch Optimizer to std::map (#5230)
Summary:
Pull Request resolved: #5230

Switch to the map API, which is directly compatible with TrainingModule. Update the simple end-to-end test to use TrainingModule as well.

Reviewed By: davidlin54

Differential Revision: D62453507

fbshipit-source-id: d40929997d42ea827a97f6fb2a1e38250ac298da
1 parent 6328d41 commit 7c76e03

File tree

7 files changed: +147 −225 lines changed
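
The "map api" the summary refers to is the same map type on both sides of the interface: TrainingModule hands out parameters keyed by their fully qualified names, and SGD now consumes parameters and gradients in exactly that shape. A minimal sketch of the shared type (the alias below is illustrative only and does not appear in the code):

  // Fully qualified parameter name -> tensor, as used by both
  // TrainingModule::named_parameters and SGD::step in this change.
  using NamedTensorMap = std::map<exec_aten::string_view, exec_aten::Tensor>;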

extension/training/module/training_module.cpp

+4
@@ -107,6 +107,10 @@ TrainingModule::named_parameters(const std::string& method_name) {
 
   uint64_t param_start = param_res.get()[0].toInt();
 
+  auto e = executorch::extension::Module::load_method(method_name);
+  if (e != runtime::Error::Ok) {
+    return e;
+  }
   auto& method = methods_.at(method_name).method;
 
   // create dict
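
From the caller's side, the practical effect of the added load_method call is that a failure to load the method now surfaces through the returned Result, rather than requiring the method to have been executed beforehand. A hedged sketch of consuming the API (the method name "forward" is a placeholder):

  auto res = training_module.named_parameters("forward");
  if (!res.ok()) {
    // Propagates load_method failures as well as the not-a-joint-graph error.
    return res.error();
  }
  const std::map<exec_aten::string_view, exec_aten::Tensor>& params = res.get();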

extension/training/module/training_module.h

+1 −2

@@ -68,8 +68,7 @@ class ET_EXPERIMENTAL TrainingModule final : executorch::extension::Module {
    * parameters for.
    *
    * @returns A Result object containing a map of the fully qualified name to
-   * parameter tensor, or an error if the method is not a joint graph or has not
-   * been executed yet.
+   * parameter tensor, or an error if the method is not a joint graph.
    */
   ET_EXPERIMENTAL
   runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>

extension/training/optimizer/sgd.cpp

+73 −94

@@ -16,7 +16,6 @@ using exec_aten::Tensor;
 using exec_aten::TensorImpl;
 using ::executorch::runtime::Error;
 using ::executorch::runtime::KernelRuntimeContext;
-using ::executorch::runtime::Span;
 
 namespace executorch {
 namespace extension {
@@ -39,25 +38,13 @@ void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
   options_ = std::move(options);
 }
 
-Span<const char*> SGDParamGroup::param_names() {
-  return param_names_;
-}
-
-const Span<const char*> SGDParamGroup::param_names() const {
-  return param_names_;
-}
-
-Span<Tensor> SGDParamGroup::param_data() {
-  return param_data_;
-}
-
-const Span<Tensor> SGDParamGroup::param_data() const {
-  return param_data_;
+const std::map<exec_aten::string_view, exec_aten::Tensor>&
+SGDParamGroup::named_parameters() const {
+  return named_parameters_;
 }
 
 void SGD::add_param_group(const SGDParamGroup& param_group) {
-  SGDParamGroup param_group_(
-      param_group.param_names(), param_group.param_data());
+  SGDParamGroup param_group_(param_group.named_parameters());
   if (!param_group.has_options()) {
     param_group_.set_options(defaults_->clone());
   } else {
@@ -66,13 +53,8 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
   param_groups_.emplace_back(std::move(param_group_));
 }
 
-Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
-  // check that the number of gradient names matches the number of gradients
-  ET_CHECK_OR_RETURN_ERROR(
-      gradient_names.size() == gradient_data.size(),
-      InvalidState,
-      "Gradient names and gradients must have the same length.");
-
+Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
+    named_gradients) {
   KernelRuntimeContext context;
   for (auto& group : param_groups_) {
     auto& options = static_cast<SGDOptions&>(group.options());
@@ -81,85 +63,82 @@ Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
     auto dampening = options.dampening();
     auto nesterov = options.nesterov();
 
-    for (int i = 0; i < group.param_names().size(); i++) {
-      for (int j = 0; j < gradient_names.size(); j++) {
-        // if param name and gradient name match, run the optimizer step
-        if (strcmp(group.param_names()[i], gradient_names[j]) == 0) {
-          auto d_p = gradient_data[j];
-          auto p = group.param_data()[i];
-          if (weight_decay != 0) {
-            // uses weight_decay specified and adds it to the gradient
-            torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
+    for (auto param_iter = group.named_parameters().begin();
+         param_iter != group.named_parameters().end();
+         ++param_iter) {
+      // if param name and gradient name match, run the optimizer step
+      const auto& named_gradient = named_gradients.find(param_iter->first);
+      if (named_gradient != named_gradients.end()) {
+        auto d_p = named_gradient->second;
+        auto p = param_iter->second;
+        if (weight_decay != 0) {
+          // uses weight_decay specified and adds it to the gradient
+          torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
+          if (context.failure_state() != Error::Ok) {
+            return context.failure_state();
           }
-          if (momentum != 0) {
-            Tensor buf(nullptr);
-            auto param_state = state_.find(p.unsafeGetTensorImpl());
-            // look for the momentum buffer for the given parameter. this is the
-            // momentum as of the previous epoch
-            if (param_state == state_.end()) {
-              // create a new momentum buffer if it doesn't exist. this memory
-              // needs to be freed when the optimizer is destroyed
-              void* buf_ptr = malloc(d_p.nbytes());
+        }
+        if (momentum != 0) {
+          Tensor buf(nullptr);
+          auto param_state = state_.find(p.unsafeGetTensorImpl());
+          // look for the momentum buffer for the given parameter. this is the
+          // momentum as of the previous epoch
+          if (param_state == state_.end()) {
+            // create a new momentum buffer if it doesn't exist. this memory
+            // needs to be freed when the optimizer is destroyed
+            void* buf_ptr = malloc(d_p.nbytes());
 
 #ifdef USE_ATEN_LIB
-              std::vector<int64_t> sizes(
-                  d_p.sizes().begin(), d_p.sizes().end());
-              buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
+            std::vector<int64_t> sizes(d_p.sizes().begin(), d_p.sizes().end());
+            buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
 #else
-              TensorImpl* buf_impl = new TensorImpl(
-                  d_p.scalar_type(),
-                  d_p.sizes().size(),
-                  const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
-                  buf_ptr,
-                  const_cast<TensorImpl::DimOrderType*>(
-                      d_p.dim_order().data()));
-              buf = Tensor(buf_impl);
+            TensorImpl* buf_impl = new TensorImpl(
+                d_p.scalar_type(),
+                d_p.sizes().size(),
+                const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
+                buf_ptr,
+                const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
+            buf = Tensor(buf_impl);
 #endif
-              torch::executor::aten::clone_outf(
-                  context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-
-              // save the state of the momentum buffer to be reused in later
-              // epochs
-              auto state = std::make_unique<SGDParamState>(buf);
-              state_[p.unsafeGetTensorImpl()] = std::move(state);
-            } else {
-              buf = static_cast<SGDParamState&>(*param_state->second)
-                        .momentum_buffer();
-
-              // update the momentum buffer and apply dampening
-              torch::executor::aten::mul_outf(context, buf, momentum, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-              torch::executor::aten::add_outf(
-                  context, buf, d_p, 1 - dampening, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
+            torch::executor::aten::clone_outf(
+                context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
             }
-            if (nesterov) {
-              // apply nesterov momentum
-              torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-            } else {
-              d_p = buf;
+
+            // save the state of the momentum buffer to be reused in later
+            // epochs
+            auto state = std::make_unique<SGDParamState>(buf);
+            state_[p.unsafeGetTensorImpl()] = std::move(state);
+          } else {
+            buf = static_cast<SGDParamState&>(*param_state->second)
+                      .momentum_buffer();
+
+            // update the momentum buffer and apply dampening
+            torch::executor::aten::mul_outf(context, buf, momentum, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
+            }
+            torch::executor::aten::add_outf(
+                context, buf, d_p, 1 - dampening, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
             }
           }
-          // update the parameter using the gradient and learning rate
-          torch::executor::aten::add_outf(
-              context, p, d_p, -1 * options.lr(), p);
-          if (context.failure_state() != Error::Ok) {
-            return context.failure_state();
+          if (nesterov) {
+            // apply nesterov momentum
+            torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
+            }
+          } else {
+            d_p = buf;
           }
-          break;
+        }
+        // update the parameter using the gradient and learning rate
+        torch::executor::aten::add_outf(context, p, d_p, -1 * options.lr(), p);
+        if (context.failure_state() != Error::Ok) {
+          return context.failure_state();
         }
       }
     }
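
The core of the change above is replacing the strcmp double loop over two parallel Spans with a keyed lookup: iterate the group's parameter map and probe the gradient map by fully qualified name. A standalone toy version of that pattern, with plain floats standing in for tensors and the ExecuTorch kernels:

  #include <map>
  #include <string_view>

  int main() {
    // Fully qualified name -> "parameter" and "gradient" stand-ins.
    std::map<std::string_view, float> params = {{"linear.weight", 1.0f},
                                                {"linear.bias", 0.5f}};
    std::map<std::string_view, float> grads = {{"linear.weight", 0.1f}};
    const float lr = 0.01f;
    for (auto& [name, p] : params) {
      auto it = grads.find(name);  // O(log n) lookup instead of strcmp over a Span
      if (it == grads.end()) {
        continue;  // no gradient for this parameter in this step
      }
      p -= lr * it->second;  // plain SGD update: p = p - lr * grad
    }
    return 0;
  }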

extension/training/optimizer/sgd.h

+22 −39

@@ -18,7 +18,7 @@
 
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/core/span.h>
+#include <map>
 #include <memory>
 #include <unordered_map>
 #include <vector>
@@ -133,52 +133,42 @@ class SGDParamGroup {
   // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has
   // to be copy-constructible.
   SGDParamGroup(const SGDParamGroup& param_group)
-      : param_data_(param_group.param_data()),
-        param_names_(param_group.param_names()),
+      : named_parameters_(param_group.named_parameters()),
         options_(
             param_group.has_options() ? param_group.options().clone()
                                       : nullptr) {}
   SGDParamGroup& operator=(const SGDParamGroup& param_group) {
-    this->param_data_ = param_group.param_data();
-    this->param_names_ = param_group.param_names();
+    this->named_parameters_ = param_group.named_parameters_;
    this->options_ =
        param_group.has_options() ? param_group.options().clone() : nullptr;
    return *this;
  }
 
   /**
-   * This constructs a SGD param group. We expect that the two spans are of the
-   * same size, and that for a given param data, its index in param_data is the
-   * same as its param name in param_name.
+   * Constructs a SGD param group.
    *
-   * @param[in] param_names The names of the params for this group.
-   * @param[in] param_data The tensors representing the param data.
+   * @param[in] named_parameters The parameters to be optimized and their fully
+   * qualified names.
    */
   /* implicit */ SGDParamGroup(
-      ::executorch::runtime::Span<const char*> param_names,
-      ::executorch::runtime::Span<exec_aten::Tensor> param_data)
-      : param_data_(std::move(param_data)),
-        param_names_(std::move(param_names)) {}
+      const std::map<exec_aten::string_view, exec_aten::Tensor>&
+          named_parameters)
+      : named_parameters_(named_parameters) {}
   SGDParamGroup(
-      ::executorch::runtime::Span<const char*> param_names,
-      ::executorch::runtime::Span<exec_aten::Tensor> param_data,
+      const std::map<exec_aten::string_view, exec_aten::Tensor>&
+          named_parameters,
       std::unique_ptr<SGDOptions> options)
-      : param_data_(std::move(param_data)),
-        param_names_(std::move(param_names)),
-        options_(std::move(options)) {}
+      : named_parameters_(named_parameters), options_(std::move(options)) {}
 
   bool has_options() const;
   SGDOptions& options();
   const SGDOptions& options() const;
   void set_options(std::unique_ptr<SGDOptions> options);
-  ::executorch::runtime::Span<const char*> param_names();
-  const ::executorch::runtime::Span<const char*> param_names() const;
-  ::executorch::runtime::Span<exec_aten::Tensor> param_data();
-  const ::executorch::runtime::Span<exec_aten::Tensor> param_data() const;
+  const std::map<exec_aten::string_view, exec_aten::Tensor>& named_parameters()
+      const;
 
  private:
-  ::executorch::runtime::Span<exec_aten::Tensor> param_data_;
-  ::executorch::runtime::Span<const char*> param_names_;
+  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters_;
   std::unique_ptr<SGDOptions> options_;
 };
 
@@ -198,11 +188,10 @@ class SGD {
   }
 
   explicit SGD(
-      ::executorch::runtime::Span<const char*> param_names,
-      ::executorch::runtime::Span<exec_aten::Tensor> param_data,
+      const std::map<exec_aten::string_view, exec_aten::Tensor>&
+          named_parameters,
       SGDOptions defaults)
-      : SGD({SGDParamGroup(std::move(param_names), std::move(param_data))},
-            defaults) {}
+      : SGD({SGDParamGroup(named_parameters)}, defaults) {}
 
   // Adds the given param_group to the optimizer's param_group list.
   void add_param_group(const SGDParamGroup& param_group);
@@ -212,18 +201,12 @@ class SGD {
   /**
    * Performs the optimization step.
    *
-   * The two spans must be of the same size. It is expected that the gradient in
-   * 'gradient_data' at index 'i' represents the gradient calculated in the loss
-   * function for the parameter with the name in 'gradient_names' at index 'i'.
-   *
-   * @param[in] gradient_names The names of the params that matches the gradient
-   * in 'gradient_data' at the same index.
-   * @param[in] gradient_data The gradient tensors to be used for optimization
-   * step.
+   * @param[in] named_gradients The gradients of the tensors specified by the
+   * fully qualified name.
    */
   ::executorch::runtime::Error step(
-      ::executorch::runtime::Span<const char*> gradient_names,
-      ::executorch::runtime::Span<exec_aten::Tensor> gradient_data);
+      const std::map<exec_aten::string_view, exec_aten::Tensor>&
+          named_gradients);
 
  private:
   std::vector<SGDParamGroup> param_groups_;
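
Putting the two headers together, the intended end-to-end flow looks roughly like the sketch below. The method name "forward", the named_gradients accessor, the SGDOptions learning-rate constructor, and the namespace spellings are assumptions for illustration rather than facts established by this diff:

  #include <executorch/extension/training/module/training_module.h>
  #include <executorch/extension/training/optimizer/sgd.h>

  // Namespaces abbreviated; adjust to wherever TrainingModule and SGD live.
  using executorch::extension::training::TrainingModule;
  using executorch::extension::training::optimizer::SGD;
  using executorch::extension::training::optimizer::SGDOptions;

  executorch::runtime::Error train_one_step(TrainingModule& module) {
    // Fully qualified name -> parameter tensor, straight from the module.
    auto param_res = module.named_parameters("forward");
    if (!param_res.ok()) {
      return param_res.error();
    }

    // The optimizer now takes that map directly; SGDOptions(lr) is assumed.
    SGD optimizer(param_res.get(), SGDOptions(0.1));

    // Assumes the joint graph has already been executed so gradients exist,
    // and that the module exposes them keyed by the same fully qualified names.
    auto grad_res = module.named_gradients("forward");
    if (!grad_res.ok()) {
      return grad_res.error();
    }

    // One optimization step over the matching parameter/gradient pairs.
    return optimizer.step(grad_res.get());
  }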

0 commit comments