
Commit 7a21ba7

Adds Storage and Gradient view/mutating methods; adds grads clamping and clipping

- Added `dfdx::nn_traits::WithGrads` trait and `dfdx_derives::WithGrads` proc macro, based on `ZeroGrads`.
- Added `dfdx_core::tensor::WithStorage` trait.
- Changed some methods on `Gradients`:
  - Exposed `get_mut` as `pub`.
  - Exposed `get_ref` as `pub`, and relaxed its receiver from `&mut self` to `&self`.
- Added gradient clamping and clipping methods.
1 parent 1175903 commit 7a21ba7

31 files changed: +536, -28 lines

dfdx-core/src/nn_traits/mod.rs

Lines changed: 129 additions & 0 deletions

@@ -113,6 +113,135 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {

```rust
    }
}

/// Something that can view or mutate a [Gradients] object.
pub trait WithGrads<E: Dtype, D: Device<E>> {
    /// View the gradient values for each parameter.
    fn grads_element_view<F: FnMut(&E)>(&self, grads: &Gradients<E, D>, f: F) {
        self.try_grads_element_view(grads, f).unwrap()
    }
    /// View the gradient values for each parameter.
    fn try_grads_element_view<F: FnMut(&E)>(
        &self,
        grads: &Gradients<E, D>,
        f: F,
    ) -> Result<(), Error>;
    /// View the gradient values for each tensor (unique id).
    fn grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) {
        self.try_grads_view(grads, f).unwrap()
    }
    /// View the gradient values for each tensor (unique id).
    fn try_grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) -> Result<(), Error>;
    /// Mutate the gradient values for each parameter.
    fn grads_element_map<F: FnMut(E) -> E>(&self, grads: &mut Gradients<E, D>, f: F) {
        self.try_grads_element_map(grads, f).unwrap()
    }
    /// Mutate the gradient values for each parameter.
    fn try_grads_element_map<F: FnMut(E) -> E>(
        &self,
        grads: &mut Gradients<E, D>,
        f: F,
    ) -> Result<(), crate::tensor::Error>;
    /// Mutate the gradient values for each tensor (unique id).
    fn grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(&self, grads: &mut Gradients<E, D>, f: F) {
        self.try_grads_map(grads, f).unwrap()
    }
    /// Mutate the gradient values for each tensor (unique id).
    fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
        &self,
        grads: &mut Gradients<E, D>,
        f: F,
    ) -> Result<(), crate::tensor::Error>;
    /// Changes the gradient values for each parameter to be between `min` and `max`.
    ///
    /// Note that this may change the "direction" of your gradients.
    fn grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E)
    where
        E: std::cmp::PartialOrd + Clone,
    {
        self.try_grads_clamp(grads, min, max).unwrap()
    }
    /// Changes the gradient values for each parameter to be between `min` and `max`.
    ///
    /// Note that this may change the "direction" of your gradients.
    fn try_grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E) -> Result<(), Error>
    where
        E: std::cmp::PartialOrd + Clone,
    {
        self.try_grads_element_map(grads, |e| {
            if e < min {
                min
            } else if e > max {
                max
            } else {
                e
            }
        })
    }
    /// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
    ///
    /// Note that this may change the "direction" of your gradients.
    fn grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E)
    where
        E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
    {
        self.try_grads_clip_value(grads, threshold).unwrap()
    }
    /// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
    ///
    /// Note that this may change the "direction" of your gradients.
    fn try_grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E) -> Result<(), Error>
    where
        E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
    {
        self.try_grads_clamp(grads, -threshold, threshold)
    }
    /// Accumulates into `acc` the squared value for the gradients.
    ///
    /// After the accumulation, taking the sqrt of `acc` results in the gradients norm.
    fn grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E)
    where
        E: num_traits::Zero + std::ops::Mul<Output = E> + num_traits::Float,
    {
        self.try_grads_norm_squared(grads, acc).unwrap()
    }
    /// Accumulates into `acc` the squared value for the gradients.
    ///
    /// After the accumulation, taking the sqrt of `acc` results in the gradients norm.
    fn try_grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E) -> Result<(), Error>
    where
        E: std::ops::Mul<Output = E> + num_traits::Float,
    {
        self.try_grads_element_view(grads, |e| *acc += *e * *e)
    }
    /// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`.
    ///
    /// Note that this doesn't change the "direction" of your gradients.
    fn grads_clip_norm(&self, grads: &mut Gradients<E, D>, norm: E, norm_threshold: E)
    where
        E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
    {
        self.try_grads_clip_norm(grads, norm, norm_threshold)
            .unwrap()
    }
    /// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`.
    ///
    /// Note that this doesn't change the "direction" of your gradients.
    fn try_grads_clip_norm(
        &self,
        grads: &mut Gradients<E, D>,
        norm: E,
        norm_threshold: E,
    ) -> Result<(), Error>
    where
        E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
    {
        if norm > norm_threshold {
            self.try_grads_element_map(grads, |e| norm_threshold * e / norm)?
        }
        Ok(())
    }
}

#[cfg(feature = "safetensors")]
/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
```
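For orientation, a minimal usage sketch of the new trait follows. The names `model` and `grads` are assumptions for illustration (any value implementing `WithGrads<f32, D>` plus a `Gradients<f32, D>` from a backward pass), not code from this commit:

```rust
// Hard-clamp every gradient element into [-0.5, 0.5] (may change gradient direction):
model.grads_clip_value(&mut grads, 0.5);

// Or rescale all gradients so their global L2 norm is at most 1.0 (direction preserved):
let mut norm_sq = 0.0f32;
model.grads_norm_squared(&grads, &mut norm_sq);
model.grads_clip_norm(&mut grads, norm_sq.sqrt(), 1.0);
```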

dfdx-core/src/nn_traits/tuples.rs

Lines changed: 19 additions & 0 deletions

@@ -67,6 +67,25 @@ macro_rules! tuple_impls {

```rust
            }
        }

        impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::WithGrads<Elem, Dev>),+> crate::nn_traits::WithGrads<Elem, Dev> for ($($name,)+) {
            fn try_grads_element_view<F: FnMut(&Elem)>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
                $(self.$idx.try_grads_element_view(grads, &mut f)?;)+
                Ok(())
            }
            fn try_grads_view<F: FnMut(&[Elem])>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
                $(self.$idx.try_grads_view(grads, &mut f)?;)+
                Ok(())
            }
            fn try_grads_element_map<F: FnMut(Elem) -> Elem>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
                $(self.$idx.try_grads_element_map(grads, &mut f)?;)+
                Ok(())
            }
            fn try_grads_map<F: FnMut(Vec<Elem>) -> Option<Vec<Elem>>>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
                $(self.$idx.try_grads_map(grads, &mut f)?;)+
                Ok(())
            }
        }

/*This macro expands like this for a 4-tuple:

impl<
```
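As a rough guide to what the macro generates, this is approximately the expansion of the element-map delegation for a 2-tuple (a sketch with paths elided, not the literal macro output):

```rust
impl<Dev: Device<Elem>, Elem: Dtype, M1, M2> WithGrads<Elem, Dev> for (M1, M2)
where
    M1: WithGrads<Elem, Dev>,
    M2: WithGrads<Elem, Dev>,
{
    fn try_grads_element_map<F: FnMut(Elem) -> Elem>(
        &self,
        grads: &mut Gradients<Elem, Dev>,
        mut f: F,
    ) -> Result<(), Error> {
        // Each tuple element forwards the same closure, in order.
        self.0.try_grads_element_map(grads, &mut f)?;
        self.1.try_grads_element_map(grads, &mut f)?;
        Ok(())
    }
    // ...the other three methods delegate to each element in the same way.
}
```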

dfdx-core/src/nn_traits/vecs.rs

Lines changed: 45 additions & 0 deletions

@@ -58,6 +58,51 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_tra

```rust
    }
}

impl<E: Dtype, D: Device<E>, T: crate::nn_traits::WithGrads<E, D>> crate::nn_traits::WithGrads<E, D>
    for Vec<T>
{
    fn try_grads_element_view<F: FnMut(&E)>(
        &self,
        grads: &crate::tensor::Gradients<E, D>,
        mut f: F,
    ) -> Result<(), crate::tensor::Error> {
        for m_i in self.iter() {
            m_i.try_grads_element_view(grads, &mut f)?;
        }
        Ok(())
    }
    fn try_grads_view<F: FnMut(&[E])>(
        &self,
        grads: &crate::tensor::Gradients<E, D>,
        mut f: F,
    ) -> Result<(), crate::tensor::Error> {
        for m_i in self.iter() {
            m_i.try_grads_view(grads, &mut f)?;
        }
        Ok(())
    }
    fn try_grads_element_map<F: FnMut(E) -> E>(
        &self,
        grads: &mut crate::tensor::Gradients<E, D>,
        mut f: F,
    ) -> Result<(), crate::tensor::Error> {
        for m_i in self.iter() {
            m_i.try_grads_element_map(grads, &mut f)?;
        }
        Ok(())
    }
    fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
        &self,
        grads: &mut crate::tensor::Gradients<E, D>,
        mut f: F,
    ) -> Result<(), crate::tensor::Error> {
        for m_i in self.iter() {
            m_i.try_grads_map(grads, &mut f)?;
        }
        Ok(())
    }
}

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
    fn write_safetensors(
```
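Since the impl only requires `T: WithGrads<E, D>`, a `Vec` of layers that individually implement the trait gets the same behavior by delegation. A short sketch, assuming hypothetical `layers: Vec<L>` with `L: WithGrads<f32, D>` and a matching `grads`:

```rust
// Zero out the gradients of every layer in the collection (element-wise map):
layers.grads_element_map(&mut grads, |_| 0.0);

// Or inspect them: count how many gradient elements are non-finite.
let mut bad = 0usize;
layers.grads_element_view(&grads, |e| {
    if !e.is_finite() {
        bad += 1;
    }
});
```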

dfdx-core/src/tensor/cpu/allocate.rs

Lines changed: 42 additions & 0 deletions

@@ -78,6 +78,48 @@ impl<E: Unit> ZeroFillStorage<E> for Cpu {

```rust
    }
}

impl<E: Unit> WithStorage<E> for Cpu {
    /// View the values by each element (in-place).
    fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
        for e in storage.iter() {
            f(e);
        }
        Ok(())
    }
    /// View the values by a [Vec] (in-place).
    fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
        f(storage.data.as_slice());
        Ok(())
    }
    /// Mutates the values by each element (in-place).
    fn try_element_map<F: FnMut(E) -> E>(
        &self,
        storage: &mut Self::Vec,
        mut f: F,
    ) -> Result<(), Error> {
        for e in storage.iter_mut() {
            let fe = f(*e);
            *e = fe;
        }
        Ok(())
    }
    /// Mutates a clone of the values (not in-place).
    ///
    /// If `Some` is returned, replaces the changed values back into the object.
    /// Otherwise if `None` is returned, the changed values are discarded and the object stays intact.
    fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
        &self,
        storage: &mut Self::Vec,
        mut f: F,
    ) -> Result<(), Error> {
        let storage_copy = storage.data.clone();
        if let Some(fstorage) = f(storage_copy) {
            storage.data.copy_from_slice(&fstorage);
        }
        Ok(())
    }
}

impl<E: Unit> OnesTensor<E> for Cpu {
    fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
        let shape = *src.shape();
```
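The `Option` return of `try_map` is what lets a caller opt out of the write-back. A sketch of that contract from inside the crate (the bindings `dev: &Cpu` and `storage: &mut <Cpu as Storage<f32>>::Vec` are assumed; `WithStorage` stays `pub(crate)`, so user code reaches this through `WithGrads` instead):

```rust
// Rescale the buffer only when its largest magnitude exceeds 1.0; otherwise leave it untouched.
dev.try_map(storage, |v: Vec<f32>| {
    let max = v.iter().fold(0.0f32, |m, e| m.max(e.abs()));
    if max > 1.0 {
        Some(v.into_iter().map(|e| e / max).collect())
    } else {
        None // nothing is copied back; the original data stays intact
    }
})?;
```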

dfdx-core/src/tensor/cuda/allocate.rs

Lines changed: 47 additions & 0 deletions

@@ -60,6 +60,53 @@ impl<E: Unit> ZeroFillStorage<E> for Cuda {

```rust
    }
}

impl<E: Unit> WithStorage<E> for Cuda {
    /// View a copy of the values by each element (not in-place).
    fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
        let v = self.dev.dtoh_sync_copy(storage)?;
        for e in v.iter() {
            f(e);
        }
        Ok(())
    }
    /// View a copy of the values by a [Vec] (not in-place).
    fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
        let v = self.dev.dtoh_sync_copy(storage)?;
        f(v.as_slice());
        Ok(())
    }
    /// Mutates a copy of the values by each element (not in-place).
    /// Then the values in Cuda memory are replaced by the changed values.
    fn try_element_map<F: FnMut(E) -> E>(
        &self,
        storage: &mut Self::Vec,
        mut f: F,
    ) -> Result<(), Error> {
        let mut v = self.dev.dtoh_sync_copy(storage)?;
        for e in v.iter_mut() {
            let fe = (&mut f)(*e);
            *e = fe;
        }
        self.dev.htod_copy_into(v, storage)?;
        Ok(())
    }
    /// Mutates a copy of the values (not in-place).
    ///
    /// If `Some` is returned, the values in Cuda memory are replaced by the changed values.
    /// Otherwise if `None` is returned, the values in Cuda memory are left intact.
    fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
        &self,
        storage: &mut Self::Vec,
        mut f: F,
    ) -> Result<(), Error> {
        let v = self.dev.dtoh_sync_copy(storage)?;
        if let Some(fv) = (&mut f)(v) {
            self.dev.htod_copy_into(fv, storage)?;
        }
        Ok(())
    }
}

impl<E: Unit> OnesTensor<E> for Cuda
where
    Cpu: OnesTensor<E>,
```
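Both devices expose the same `WithStorage` surface, the CPU in place and CUDA via device-to-host and host-to-device copies, so gradient utilities can stay device-generic. A sketch of such a helper built on the `try_` variants (illustrative only, not part of this commit; it assumes `Dtype`, `Device`, `Gradients`, `Error`, and `WithGrads` are in scope via the dfdx prelude):

```rust
/// Clips a model's gradients to a maximum L2 norm, regardless of device.
fn clip_grad_norm<E, D, M>(model: &M, grads: &mut Gradients<E, D>, max_norm: E) -> Result<(), Error>
where
    E: Dtype + num_traits::Float,
    D: Device<E>,
    M: WithGrads<E, D>,
{
    // Accumulate the squared norm across every parameter's gradient...
    let mut norm_sq = E::zero();
    model.try_grads_norm_squared(grads, &mut norm_sq)?;
    // ...then rescale; try_grads_clip_norm is a no-op when the norm is already within the threshold.
    model.try_grads_clip_norm(grads, norm_sq.sqrt(), max_norm)
}
```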

dfdx-core/src/tensor/gradients.rs

Lines changed: 2 additions & 2 deletions

@@ -86,14 +86,14 @@ impl<E, D: Storage<E>> Gradients<E, D> {

```diff
     /// Returns a mutable reference to the data associated with `t`.
     ///
     /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
-    pub(crate) fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec {
+    pub fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec {
         self.gradient_by_id.get_mut(&t.id()).unwrap()
     }

     /// Returns an immutable reference to the data associated with `t`.
     ///
     /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
-    pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
+    pub fn get_ref<S: Shape>(&self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
         self.gradient_by_id.get(&t.id()).unwrap()
     }
```
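With `get_ref` now public and taking `&self`, a specific parameter's gradient buffer can be borrowed without exclusive access to the whole `Gradients` map. A small sketch (the tensor `weight` is hypothetical):

```rust
// Borrow the raw gradient buffer (a D::Vec) recorded for one tensor.
// Panics if no gradient for `weight` exists, as documented above.
let weight_grad = grads.get_ref(&weight);
```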

dfdx-core/src/tensor/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ mod tensor_impls;

```diff

 pub use error::Error;
 pub(crate) use ghost::GhostTensor;
-pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};
+pub(crate) use storage_traits::{OneFillStorage, WithStorage, ZeroFillStorage};
 pub use tensorlike::Tensorlike;

 pub use cpu::Cpu;
```
