Gradient Clipping #902

Draft
wants to merge 8 commits into base: main

2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
pub use crate::tensor_ops::*;
}

/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn flush_denormals_to_zero() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}
}

/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn keep_denormals() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}
}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
129 changes: 129 additions & 0 deletions dfdx-core/src/nn_traits/mod.rs
@@ -113,6 +113,135 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
}
}

/// Something that can view or mutate a [Gradients] object.
pub trait WithGrads<E: Dtype, D: Device<E>> {
/// View the gradient values for each parameter.
fn grads_element_view<F: FnMut(&E)>(&self, grads: &Gradients<E, D>, f: F) {
self.try_grads_element_view(grads, f).unwrap()
}
/// View the gradient values for each parameter.
fn try_grads_element_view<F: FnMut(&E)>(
&self,
grads: &Gradients<E, D>,
f: F,
) -> Result<(), Error>;
/// View the gradient values for each tensor (unique id).
fn grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) {
self.try_grads_view(grads, f).unwrap()
}
/// View the gradient values for each tensor (unique id).
fn try_grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) -> Result<(), Error>;
/// Mutate the gradient values for each parameter.
fn grads_element_map<F: FnMut(E) -> E>(&self, grads: &mut Gradients<E, D>, f: F) {
self.try_grads_element_map(grads, f).unwrap()
}
/// Mutate the gradient values for each parameter.
fn try_grads_element_map<F: FnMut(E) -> E>(
&self,
grads: &mut Gradients<E, D>,
f: F,
) -> Result<(), crate::tensor::Error>;
/// Mutate the gradient values for each tensor (unique id).
fn grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(&self, grads: &mut Gradients<E, D>, f: F) {
self.try_grads_map(grads, f).unwrap()
}
/// Mutate the gradient values for each tensor (unique id).
fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
grads: &mut Gradients<E, D>,
f: F,
) -> Result<(), crate::tensor::Error>;
/// Changes the gradient values for each parameter to be between `min` and `max`.
///
/// Note that this may change the "direction" of your gradients.
fn grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E)
where
E: std::cmp::PartialOrd + Clone,
{
self.try_grads_clamp(grads, min, max).unwrap()
}
/// Changes the gradient values for each parameter to be between `min` and `max`.
///
/// Note that this may change the "direction" of your gradients.
fn try_grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E) -> Result<(), Error>
where
E: std::cmp::PartialOrd + Clone,
{
self.try_grads_element_map(grads, |e| {
if e < min {
min
} else if e > max {
max
} else {
e
}
})
}
/// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
///
/// Note that this may change the "direction" of your gradients.
fn grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E)
where
E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
{
self.try_grads_clip_value(grads, threshold).unwrap()
}
/// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
///
/// Note that this may change the "direction" of your gradients.
fn try_grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E) -> Result<(), Error>
where
E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
{
self.try_grads_clamp(grads, -threshold, threshold)
}
/// Accumulates into `acc` the squared values of the gradients.
///
/// After the accumulation, taking the square root of `acc` gives the gradient norm.
fn grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E)
where
E: num_traits::Zero + std::ops::Mul<Output = E> + num_traits::Float,
{
self.try_grads_norm_squared(grads, acc).unwrap()
}
/// Accumulates into `acc` the squared values of the gradients.
///
/// After the accumulation, taking the square root of `acc` gives the gradient norm.
fn try_grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E) -> Result<(), Error>
where
E: std::ops::Mul<Output = E> + num_traits::Float,
{
self.try_grads_element_view(grads, |e| *acc += *e * *e)
}
/// Given the `norm` of all gradient values, scales all gradients down so that their norm does not exceed `norm_threshold`.
///
/// Note that this doesn't change the "direction" of your gradients.
fn grads_clip_norm(&self, grads: &mut Gradients<E, D>, norm: E, norm_threshold: E)
where
E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
{
self.try_grads_clip_norm(grads, norm, norm_threshold)
.unwrap()
}
/// Given the `norm` of all gradient values, scales all gradients down so that their norm does not exceed `norm_threshold`.
///
/// Note that this doesn't change the "direction" of your gradients.
fn try_grads_clip_norm(
&self,
grads: &mut Gradients<E, D>,
norm: E,
norm_threshold: E,
) -> Result<(), Error>
where
E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
{
if norm > norm_threshold {
self.try_grads_element_map(grads, |e| norm_threshold * e / norm)?
}
Ok(())
}
}

#[cfg(feature = "safetensors")]
/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
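As a usage sketch (not part of this diff), the new `WithGrads` methods compose into the two usual clipping strategies: element-wise value clipping via `try_grads_clip_value`, and global-norm clipping via `try_grads_norm_squared` followed by `try_grads_clip_norm`. The helpers below are hypothetical; they assume `WithGrads`, `Gradients`, `Device`, `Dtype`, and `Error` are reachable through `dfdx_core::prelude` once this change lands, and that `num_traits` is available as a dependency.

```rust
use dfdx_core::prelude::*;

/// Element-wise clipping: every gradient value ends up in `[-threshold, +threshold]`.
fn clip_by_value<E, D, M>(model: &M, grads: &mut Gradients<E, D>, threshold: E) -> Result<(), Error>
where
    E: Dtype + num_traits::Float,
    D: Device<E>,
    M: WithGrads<E, D>,
{
    model.try_grads_clip_value(grads, threshold)
}

/// Global-norm clipping: accumulate the squared gradient values, then rescale
/// all gradients so their total L2 norm does not exceed `max_norm`.
fn clip_by_global_norm<E, D, M>(model: &M, grads: &mut Gradients<E, D>, max_norm: E) -> Result<(), Error>
where
    E: Dtype + num_traits::Float,
    D: Device<E>,
    M: WithGrads<E, D>,
{
    let mut norm_sq = E::zero();
    model.try_grads_norm_squared(grads, &mut norm_sq)?;
    // `try_grads_clip_norm` only rescales when the norm exceeds `max_norm`,
    // so the direction of the gradients is preserved.
    model.try_grads_clip_norm(grads, norm_sq.sqrt(), max_norm)
}
```

In a training loop, either helper would sit between the backward pass and the optimizer step: compute `grads` with the backward call, run e.g. `clip_by_global_norm(&model, &mut grads, 1.0)`, and only then hand `grads` to the optimizer update.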
19 changes: 19 additions & 0 deletions dfdx-core/src/nn_traits/tuples.rs
@@ -67,6 +67,25 @@ macro_rules! tuple_impls {
}
}

impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::WithGrads<Elem, Dev>),+> crate::nn_traits::WithGrads<Elem, Dev> for ($($name,)+) {
fn try_grads_element_view<F: FnMut(&Elem)>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_element_view(grads, &mut f)?;)+
Ok(())
}
fn try_grads_view<F: FnMut(&[Elem])>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_view(grads, &mut f)?;)+
Ok(())
}
fn try_grads_element_map<F: FnMut(Elem) -> Elem>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_element_map(grads, &mut f)?;)+
Ok(())
}
fn try_grads_map<F: FnMut(Vec<Elem>) -> Option<Vec<Elem>>>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_map(grads, &mut f)?;)+
Ok(())
}
}

/*This macro expands like this for a 4-tuple:

impl<
45 changes: 45 additions & 0 deletions dfdx-core/src/nn_traits/vecs.rs
@@ -58,6 +58,51 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_tra
}
}

impl<E: Dtype, D: Device<E>, T: crate::nn_traits::WithGrads<E, D>> crate::nn_traits::WithGrads<E, D>
for Vec<T>
{
fn try_grads_element_view<F: FnMut(&E)>(
&self,
grads: &crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_element_view(grads, &mut f)?;
}
Ok(())
}
fn try_grads_view<F: FnMut(&[E])>(
&self,
grads: &crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_view(grads, &mut f)?;
}
Ok(())
}
fn try_grads_element_map<F: FnMut(E) -> E>(
&self,
grads: &mut crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_element_map(grads, &mut f)?;
}
Ok(())
}
fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
grads: &mut crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_map(grads, &mut f)?;
}
Ok(())
}
}

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
fn write_safetensors(
42 changes: 42 additions & 0 deletions dfdx-core/src/tensor/cpu/allocate.rs
@@ -78,6 +78,48 @@ impl<E: Unit> ZeroFillStorage<E> for Cpu {
}
}

impl<E: Unit> WithStorage<E> for Cpu {
/// View each element of the values (in-place).
fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
for e in storage.iter() {
f(e);
}
Ok(())
}
/// View the values as a slice (in-place).
fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
f(storage.data.as_slice());
Ok(())
}
/// Mutates each element of the values (in-place).
fn try_element_map<F: FnMut(E) -> E>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
for e in storage.iter_mut() {
let fe = f(*e);
*e = fe;
}
Ok(())
}
/// Mutates a clone of the values (not in-place).
///
/// If `Some` is returned, the changed values are written back into the storage.
/// Otherwise, if `None` is returned, the changed values are discarded and the storage is left intact.
fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let storage_copy = storage.data.clone();
if let Some(fstorage) = f(storage_copy) {
storage.data.copy_from_slice(&fstorage);
}
Ok(())
}
}

impl<E: Unit> OnesTensor<E> for Cpu {
fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
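As an aside (not part of this diff), the `try_map` contract above — clone the storage, hand the clone to the closure, and only write back when it returns `Some` — can be illustrated with a standalone analogue over a plain `Vec<f32>`:

```rust
/// Standalone analogue of the CPU `try_map` above: the closure receives a
/// clone of the storage, and only a returned `Some` is copied back.
fn map_storage<F>(storage: &mut Vec<f32>, mut f: F)
where
    F: FnMut(Vec<f32>) -> Option<Vec<f32>>,
{
    let copy = storage.clone();
    if let Some(changed) = f(copy) {
        storage.copy_from_slice(&changed);
    }
}

fn main() {
    let mut values = vec![1.0_f32, -2.0, 3.0];
    // Returning `Some` writes the scaled values back into the storage.
    map_storage(&mut values, |v| Some(v.into_iter().map(|e| e * 0.5).collect()));
    assert_eq!(values, vec![0.5, -1.0, 1.5]);
    // Returning `None` leaves the storage untouched.
    map_storage(&mut values, |_| None);
    assert_eq!(values, vec![0.5, -1.0, 1.5]);
}
```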
47 changes: 47 additions & 0 deletions dfdx-core/src/tensor/cuda/allocate.rs
@@ -60,6 +60,53 @@ impl<E: Unit> ZeroFillStorage<E> for Cuda {
}
}

impl<E: Unit> WithStorage<E> for Cuda {
/// View each element of a host copy of the values (not in-place).
fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
for e in v.iter() {
f(e);
}
Ok(())
}
/// View a host copy of the values as a slice (not in-place).
fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
f(v.as_slice());
Ok(())
}
/// Mutates each element of a host copy of the values (not in-place),
/// then writes the changed values back into CUDA memory.
fn try_element_map<F: FnMut(E) -> E>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let mut v = self.dev.dtoh_sync_copy(storage)?;
for e in v.iter_mut() {
let fe = (&mut f)(*e);
*e = fe;
}
self.dev.htod_copy_into(v, storage)?;
Ok(())
}
/// Mutates a copy of the values (not in-place).
///
/// If `Some` is returned, the values in CUDA memory are replaced by the changed values.
/// Otherwise, if `None` is returned, the values in CUDA memory are left intact.
fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
if let Some(fv) = (&mut f)(v) {
self.dev.htod_copy_into(fv, storage)?;
}
Ok(())
}
}

impl<E: Unit> OnesTensor<E> for Cuda
where
Cpu: OnesTensor<E>,