Skip to content

Commit ab58a49

Browse files
Add GradMonitor handler with Welford's algorithm and comprehensive unit tests
1 parent 2d7a8cc commit ab58a49

3 files changed

Lines changed: 576 additions & 0 deletions

File tree

ignite/handlers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ignite.handlers.early_stopping import EarlyStopping
66
from ignite.handlers.ema_handler import EMAHandler
77
from ignite.handlers.fbresearch_logger import FBResearchLogger
8+
from ignite.handlers.grad_monitor import GradMonitor
89
from ignite.handlers.lr_finder import FastaiLRFinder
910
from ignite.handlers.mlflow_logger import MLflowLogger
1011
from ignite.handlers.neptune_logger import NeptuneLogger
@@ -50,6 +51,7 @@
5051
"Timer",
5152
"EarlyStopping",
5253
"TerminateOnNan",
54+
"GradMonitor",
5355
"global_step_from_engine",
5456
"TimeLimit",
5557
"EpochOutputStore",

ignite/handlers/grad_monitor.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
import torch
2+
from typing import Any, Callable, Optional
3+
from ignite.engine import Engine, Events
4+
5+
6+
def _default_spike_detector(mean: float, m2: float, count: int, norm: float, k: float) -> bool:
7+
"""
8+
Default spike detection rule: flags if norm > mean + k * std.
9+
10+
Args:
11+
mean: Running mean of gradient norms.
12+
m2: Running sum of squared deviations (Welford's algorithm).
13+
count: Number of iterations seen so far.
14+
norm: Current gradient norm.
15+
k: Standard deviation multiplier.
16+
17+
Returns:
18+
True if the current norm is a spike, False otherwise.
19+
"""
20+
if count < 2:
21+
return False
22+
std = (m2 / (count - 1)) ** 0.5
23+
return norm > mean + k * std + 1e-6
24+
25+
26+
class GradMonitor:
    """
    Monitors the L2 gradient norm each iteration to detect training instability
    (e.g. exploding gradients, entropy collapse) using a dynamic threshold.

    The handler attaches to ``Events.ITERATION_STARTED``, meaning it reads
    gradients computed during the **previous** iteration. The spike flag is
    set on ``engine.state.unhealthy_spike`` and can be checked at the start
    of the next iteration's train step.

    .. warning::
        This handler requires that gradients are **not zeroed** at the end of
        your train step. If you call ``optimizer.zero_grad()`` at the end of
        ``train_step``, all gradients will be gone by the time this handler
        runs and the norm will always be 0. Instead, call
        ``optimizer.zero_grad()`` at the **start** of your train step, after
        checking the spike flag.

    .. warning::
        The ``unhealthy_spike`` flag on ``engine.state`` reflects gradients
        from the **previous** iteration, not the current one. Design your
        train step accordingly.

    .. warning::
        Call ``attach`` only once per ``GradMonitor`` instance. Calling it a
        second time raises a ``RuntimeError`` to prevent doubled stat updates
        and corrupted running statistics.

    .. note::
        If multiple handlers are registered to ``Events.ITERATION_STARTED``,
        execution order depends on registration order. Register
        ``GradMonitor`` before any handler that reads ``unhealthy_spike``
        to ensure the flag is fresh when read.

    .. note::
        A non-finite gradient norm (``inf`` or ``NaN``) always sets
        ``unhealthy_spike`` to True and is excluded from the running
        statistics, so a single overflow cannot corrupt the mean and variance
        used for later iterations.

    Args:
        model: The model whose gradient norms will be monitored.
        k: Multiplier for standard deviation when using the default spike
            detector. Ignored if ``spike_detector`` is provided. Default: 3.0.
        scaler: Optional ``torch.cuda.amp.GradScaler`` instance for AMP
            workflows. When provided, the raw gradient norm is divided by
            the current scale factor to recover the true unscaled norm.
        spike_detector: Optional callable with signature
            ``(mean, m2, count, norm, k) -> bool``. If provided, replaces
            the default ``mean + k * std`` rule entirely. Use this to
            implement custom thresholding logic. It is only invoked for
            finite norms.

    Raises:
        TypeError: If ``model`` is not a ``torch.nn.Module``, ``k`` is not
            numeric, or ``spike_detector`` is given but not callable.
        ValueError: If ``k`` is not positive.

    Example:
        .. code-block:: python

            import torch
            from ignite.engine import Engine, Events
            from ignite.handlers import GradMonitor

            model = torch.nn.Linear(10, 1)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

            def train_step(engine, batch):
                # Check the spike flag set by GradMonitor from the previous
                # iteration. On iteration 1 this will always be False
                # (not enough history yet).
                if engine.state.unhealthy_spike:
                    # Discard this batch entirely.
                    # Do NOT zero grads here so GradMonitor can still
                    # read them on the next iteration.
                    return {"loss": None, "skipped": True}

                optimizer.zero_grad()  # zero at START, not end
                x, y = batch
                loss = ((model(x) - y) ** 2).mean()
                loss.backward()
                optimizer.step()
                # Do NOT call optimizer.zero_grad() here.
                return {"loss": loss.item(), "skipped": False}

            trainer = Engine(train_step)

            # Attach GradMonitor BEFORE any handler that reads unhealthy_spike
            # to guarantee correct ordering on Events.ITERATION_STARTED.
            monitor = GradMonitor(model, k=3.0)
            monitor.attach(trainer)

            # This handler runs after GradMonitor because it is registered after.
            @trainer.on(Events.ITERATION_STARTED)
            def log_spike(engine):
                if engine.state.unhealthy_spike:
                    print(f"Spike at iteration {engine.state.iteration}!")

    Example with AMP:
        .. code-block:: python

            scaler = torch.cuda.amp.GradScaler()
            monitor = GradMonitor(model, k=3.0, scaler=scaler)
            monitor.attach(trainer)

    Example with a custom spike detector:
        .. code-block:: python

            def my_detector(mean, m2, count, norm, k):
                # Flag only if norm exceeds an absolute limit of 100.
                return norm > 100.0

            monitor = GradMonitor(model, spike_detector=my_detector)
            monitor.attach(trainer)

    .. versionadded:: 0.6.0
    """

    def __init__(
        self,
        model: torch.nn.Module,
        k: float = 3.0,
        scaler: Optional[Any] = None,
        spike_detector: Optional[Callable] = None,
    ):
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"model must be a torch.nn.Module, got {type(model)}"
            )
        if not isinstance(k, (int, float)):
            raise TypeError(
                f"k must be a numeric value, got {type(k)}"
            )
        if k <= 0:
            raise ValueError(
                f"k must be a positive number, got {k}"
            )
        # Fail at construction time rather than on the first iteration.
        if spike_detector is not None and not callable(spike_detector):
            raise TypeError(
                f"spike_detector must be callable, got {type(spike_detector)}"
            )

        self.model = model
        self.k = k
        self.scaler = scaler
        self.spike_detector = spike_detector if spike_detector is not None else _default_spike_detector
        # Device of the norm accumulator; resolved lazily on first use.
        self._device: Optional[torch.device] = None
        self._attached: bool = False
        # Welford running statistics over the observed gradient norms.
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0

    def _get_device(self) -> torch.device:
        """
        Return the device used to accumulate the squared gradient norm.

        The device of the first model parameter is used and cached; a model
        without parameters falls back to CPU.
        """
        if self._device is None:
            try:
                self._device = next(self.model.parameters()).device
            except StopIteration:
                self._device = torch.device("cpu")
        return self._device

    def _compute_grad_norm(self) -> float:
        """
        Computes the global L2 gradient norm across all model parameters.

        Handles:
            - Low precision: float16 / bfloat16 gradients are upcast to
              float32 before squaring so the square cannot overflow the
              narrower dtype.
            - Multiple devices: each per-parameter sum is moved to the
              accumulator device before being added.
            - AMP: divides by GradScaler scale if a scaler was provided.
            - Distributed: sums squared norms across all processes.

        Returns:
            The L2 gradient norm as a Python float. Parameters without a
            gradient are skipped, so the result is 0.0 when no parameter has
            a gradient.
        """
        device = self._get_device()
        total_norm_sq = torch.tensor(0.0, device=device)
        low_precision = (torch.float16, torch.bfloat16)

        for p in self.model.parameters():
            if p.grad is None:
                continue
            grad = p.grad.detach()
            if grad.dtype in low_precision:
                # e.g. 300.0 ** 2 already exceeds the float16 maximum (65504).
                grad = grad.float()
            # ``.to(device)`` is a no-op when the gradient already lives on
            # the accumulator device.
            total_norm_sq += grad.pow(2).sum().to(device)

        # Distributed support: sum squared norms across all processes so
        # every node sees the same global norm. Only the import is guarded,
        # so errors raised by the collective itself are not swallowed.
        # NOTE(review): if gradients are already synchronised across ranks
        # (as with DistributedDataParallel), this sum scales the norm by
        # sqrt(world_size) -- confirm this is the intended behaviour.
        try:
            from ignite.distributed import all_reduce, get_world_size
        except ImportError:
            pass
        else:
            if get_world_size() > 1:
                total_norm_sq = all_reduce(total_norm_sq)

        total_norm: float = torch.sqrt(total_norm_sq).item()

        # AMP support: divide by the scaler's scale factor to recover the
        # true unscaled gradient norm.
        # NOTE(review): this handler runs after the previous train step; if
        # that step called ``scaler.step()`` / ``scaler.unscale_()`` the
        # gradients may already be unscaled, and ``scaler.update()`` may have
        # changed the scale since they were computed -- confirm against the
        # train step this is used with.
        if self.scaler is not None:
            scale = self.scaler.get_scale()
            if scale > 0:
                total_norm /= scale

        return total_norm

    def _update_stats(self, norm: float) -> None:
        """
        Update running mean and variance using Welford's online algorithm.
        O(1) memory, numerically stable.

        Args:
            norm: The gradient norm from the current iteration.
        """
        self.count += 1
        delta = norm - self.mean
        self.mean += delta / self.count
        # Deviation from the *updated* mean; the product of both deviations
        # is the increment of the sum of squared deviations.
        delta2 = norm - self.mean
        self.m2 += delta * delta2

    def __call__(self, engine: "Engine") -> None:
        """
        Called at every ``Events.ITERATION_STARTED``.

        Computes the gradient norm, evaluates the spike detector, sets
        ``engine.state.unhealthy_spike``, then updates running statistics.

        A non-finite norm sets the flag to True and leaves the running
        statistics untouched: folding ``inf`` or ``NaN`` into the mean would
        make every later threshold comparison False and silently disable
        detection for the rest of training.

        Args:
            engine: The Ignite training engine.
        """
        # Function-scope import keeps this method self-contained.
        import math

        norm = self._compute_grad_norm()

        if not math.isfinite(norm):
            engine.state.unhealthy_spike = True
            return

        # Evaluate spike BEFORE updating stats so the current norm is
        # compared against history from previous iterations only. ``bool``
        # normalises detector results such as numpy or tensor scalars.
        engine.state.unhealthy_spike = bool(
            self.spike_detector(self.mean, self.m2, self.count, norm, self.k)
        )

        self._update_stats(norm)

    def attach(self, engine: "Engine") -> "GradMonitor":
        """
        Attach this handler to an engine.

        Initialises ``engine.state.unhealthy_spike = False`` at the start
        of each run so the flag is always safe to read inside the train step
        from the very first iteration.

        Args:
            engine: The Ignite training engine.

        Returns:
            self, to allow fluent chaining:
            ``GradMonitor(model).attach(trainer)``.

        Raises:
            RuntimeError: If this handler has already been attached to an engine.
        """
        if self._attached:
            raise RuntimeError(
                "GradMonitor is already attached to an engine. "
                "Create a new GradMonitor instance to attach to a different engine."
            )

        @engine.on(Events.STARTED)
        def _init_flag(e: "Engine") -> None:
            e.state.unhealthy_spike = False

        # Register the flag as a user key when the engine exposes such a
        # list, without adding a duplicate if another monitor already did.
        user_keys = getattr(engine, "state_dict_user_keys", None)
        if user_keys is not None and "unhealthy_spike" not in user_keys:
            user_keys.append("unhealthy_spike")

        engine.add_event_handler(Events.ITERATION_STARTED, self)
        # Mark as attached only once registration has succeeded, so a failed
        # attempt does not leave the monitor unusable.
        self._attached = True
        return self

0 commit comments

Comments
 (0)