@@ -16,7 +16,6 @@ using exec_aten::Tensor;
 using exec_aten::TensorImpl;
 using ::executorch::runtime::Error;
 using ::executorch::runtime::KernelRuntimeContext;
-using ::executorch::runtime::Span;
 
 namespace executorch {
 namespace extension {
@@ -39,25 +38,13 @@ void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
   options_ = std::move(options);
 }
 
-Span<const char*> SGDParamGroup::param_names() {
-  return param_names_;
-}
-
-const Span<const char*> SGDParamGroup::param_names() const {
-  return param_names_;
-}
-
-Span<Tensor> SGDParamGroup::param_data() {
-  return param_data_;
-}
-
-const Span<Tensor> SGDParamGroup::param_data() const {
-  return param_data_;
+const std::map<exec_aten::string_view, exec_aten::Tensor>&
+SGDParamGroup::named_parameters() const {
+  return named_parameters_;
 }
 
 void SGD::add_param_group(const SGDParamGroup& param_group) {
-  SGDParamGroup param_group_(
-      param_group.param_names(), param_group.param_data());
+  SGDParamGroup param_group_(param_group.named_parameters());
   if (!param_group.has_options()) {
     param_group_.set_options(defaults_->clone());
   } else {
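
The four Span-based accessors removed above kept parameter names and parameter tensors in two parallel arrays that callers had to keep in sync; the replacement stores both in a single map keyed by fully qualified parameter name. A minimal caller-side sketch follows (the map-taking SGDParamGroup constructor is the one exercised by add_param_group() in this hunk; `weight`, `bias`, and `sgd` are hypothetical, not part of this change):

    // Hypothetical caller: collect parameters by name, then register them
    // as a single param group on an existing SGD instance `sgd`.
    std::map<exec_aten::string_view, exec_aten::Tensor> named_params;
    named_params.insert({"linear.weight", weight}); // weight: exec_aten::Tensor
    named_params.insert({"linear.bias", bias});     // bias: exec_aten::Tensor
    executorch::extension::SGDParamGroup group(named_params);
    sgd.add_param_group(group);
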
@@ -66,13 +53,8 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
   param_groups_.emplace_back(std::move(param_group_));
 }
 
-Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
-  // check that the number of gradient names matches the number of gradients
-  ET_CHECK_OR_RETURN_ERROR(
-      gradient_names.size() == gradient_data.size(),
-      InvalidState,
-      "Gradient names and gradients must have the same length.");
-
+Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
+                    named_gradients) {
   KernelRuntimeContext context;
   for (auto& group : param_groups_) {
     auto& options = static_cast<SGDOptions&>(group.options());
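
With gradients keyed by name, the parallel-array length check deleted above has no equivalent: a single map cannot disagree with itself about sizes, and map lookup replaces the nested strcmp scan rewritten in the next hunk. A sketch of a call under the new signature (the gradient tensors and the `sgd` instance are hypothetical; keys must match the parameter names registered with the optimizer):

    // Hypothetical caller: gradients keyed by the same fully qualified
    // names as the parameters they correspond to.
    std::map<exec_aten::string_view, exec_aten::Tensor> named_gradients;
    named_gradients.insert({"linear.weight", weight_grad});
    named_gradients.insert({"linear.bias", bias_grad});
    Error err = sgd.step(named_gradients);
    if (err != Error::Ok) {
      // surface the failure to the caller
    }
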
@@ -81,85 +63,82 @@ Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
     auto dampening = options.dampening();
     auto nesterov = options.nesterov();
 
-    for (int i = 0; i < group.param_names().size(); i++) {
-      for (int j = 0; j < gradient_names.size(); j++) {
-        // if param name and gradient name match, run the optimizer step
-        if (strcmp(group.param_names()[i], gradient_names[j]) == 0) {
-          auto d_p = gradient_data[j];
-          auto p = group.param_data()[i];
-          if (weight_decay != 0) {
-            // uses weight_decay specified and adds it to the gradient
-            torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
+    for (auto param_iter = group.named_parameters().begin();
+         param_iter != group.named_parameters().end();
+         ++param_iter) {
+      // if param name and gradient name match, run the optimizer step
+      const auto& named_gradient = named_gradients.find(param_iter->first);
+      if (named_gradient != named_gradients.end()) {
+        auto d_p = named_gradient->second;
+        auto p = param_iter->second;
+        if (weight_decay != 0) {
+          // uses weight_decay specified and adds it to the gradient
+          torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
+          if (context.failure_state() != Error::Ok) {
+            return context.failure_state();
           }
-          if (momentum != 0) {
-            Tensor buf(nullptr);
-            auto param_state = state_.find(p.unsafeGetTensorImpl());
-            // look for the momentum buffer for the given parameter. this is the
-            // momentum as of the previous epoch
-            if (param_state == state_.end()) {
-              // create a new momentum buffer if it doesn't exist. this memory
-              // needs to be freed when the optimizer is destroyed
-              void* buf_ptr = malloc(d_p.nbytes());
+        }
+        if (momentum != 0) {
+          Tensor buf(nullptr);
+          auto param_state = state_.find(p.unsafeGetTensorImpl());
+          // look for the momentum buffer for the given parameter. this is the
+          // momentum as of the previous epoch
+          if (param_state == state_.end()) {
+            // create a new momentum buffer if it doesn't exist. this memory
+            // needs to be freed when the optimizer is destroyed
+            void* buf_ptr = malloc(d_p.nbytes());
 
 #ifdef USE_ATEN_LIB
-              std::vector<int64_t> sizes(
-                  d_p.sizes().begin(), d_p.sizes().end());
-              buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
+            std::vector<int64_t> sizes(d_p.sizes().begin(), d_p.sizes().end());
+            buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
 #else
-              TensorImpl* buf_impl = new TensorImpl(
-                  d_p.scalar_type(),
-                  d_p.sizes().size(),
-                  const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
-                  buf_ptr,
-                  const_cast<TensorImpl::DimOrderType*>(
-                      d_p.dim_order().data()));
-              buf = Tensor(buf_impl);
+            TensorImpl* buf_impl = new TensorImpl(
+                d_p.scalar_type(),
+                d_p.sizes().size(),
+                const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
+                buf_ptr,
+                const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
+            buf = Tensor(buf_impl);
 #endif
-              torch::executor::aten::clone_outf(
-                  context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-
-              // save the state of the momentum buffer to be reused in later
-              // epochs
-              auto state = std::make_unique<SGDParamState>(buf);
-              state_[p.unsafeGetTensorImpl()] = std::move(state);
-            } else {
-              buf = static_cast<SGDParamState&>(*param_state->second)
-                        .momentum_buffer();
-
-              // update the momentum buffer and apply dampening
-              torch::executor::aten::mul_outf(context, buf, momentum, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-              torch::executor::aten::add_outf(
-                  context, buf, d_p, 1 - dampening, buf);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
+            torch::executor::aten::clone_outf(
+                context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
             }
-            if (nesterov) {
-              // apply nesterov momentum
-              torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
-              if (context.failure_state() != Error::Ok) {
-                return context.failure_state();
-              }
-            } else {
-              d_p = buf;
+
+            // save the state of the momentum buffer to be reused in later
+            // epochs
+            auto state = std::make_unique<SGDParamState>(buf);
+            state_[p.unsafeGetTensorImpl()] = std::move(state);
+          } else {
+            buf = static_cast<SGDParamState&>(*param_state->second)
+                      .momentum_buffer();
+
+            // update the momentum buffer and apply dampening
+            torch::executor::aten::mul_outf(context, buf, momentum, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
+            }
+            torch::executor::aten::add_outf(
+                context, buf, d_p, 1 - dampening, buf);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
             }
           }
-          // update the parameter using the gradient and learning rate
-          torch::executor::aten::add_outf(
-              context, p, d_p, -1 * options.lr(), p);
-          if (context.failure_state() != Error::Ok) {
-            return context.failure_state();
+          if (nesterov) {
+            // apply nesterov momentum
+            torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
+            if (context.failure_state() != Error::Ok) {
+              return context.failure_state();
+            }
+          } else {
+            d_p = buf;
           }
-          break;
+        }
+        // update the parameter using the gradient and learning rate
+        torch::executor::aten::add_outf(context, p, d_p, -1 * options.lr(), p);
+        if (context.failure_state() != Error::Ok) {
+          return context.failure_state();
         }
       }
     }
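
For reference, the loop above is the standard torch.optim.SGD update expressed through out-variant kernels. A scalar sketch of the per-element steady-state rule (after the first step, where the momentum buffer is instead initialized to a clone of the gradient); this is an illustration, not code from the change:

    // g: gradient, p: parameter, b: momentum buffer (persisted in state_).
    float sgd_update(float p, float g, float& b, float lr, float momentum,
                     float dampening, float weight_decay, bool nesterov) {
      g += weight_decay * p;                       // add_outf(d_p, p, weight_decay)
      if (momentum != 0.0f) {
        b = momentum * b + (1.0f - dampening) * g; // mul_outf + add_outf
        g = nesterov ? g + momentum * b : b;       // nesterov vs. plain momentum
      }
      return p - lr * g;                           // add_outf(p, d_p, -lr)
    }
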