1+ import torch
2+ from typing import Any , Callable , Optional
3+ from ignite .engine import Engine , Events
4+
5+
6+ def _default_spike_detector (mean : float , m2 : float , count : int , norm : float , k : float ) -> bool :
7+ """
8+ Default spike detection rule: flags if norm > mean + k * std.
9+
10+ Args:
11+ mean: Running mean of gradient norms.
12+ m2: Running sum of squared deviations (Welford's algorithm).
13+ count: Number of iterations seen so far.
14+ norm: Current gradient norm.
15+ k: Standard deviation multiplier.
16+
17+ Returns:
18+ True if the current norm is a spike, False otherwise.
19+ """
20+ if count < 2 :
21+ return False
22+ std = (m2 / (count - 1 )) ** 0.5
23+ return norm > mean + k * std + 1e-6
24+
25+
26+ class GradMonitor :
27+ """
28+ Monitors the L2 gradient norm each iteration to detect training instability
29+ (e.g. exploding gradients, entropy collapse) using a dynamic threshold.
30+
31+ The handler attaches to ``Events.ITERATION_STARTED``, meaning it reads
32+ gradients computed during the **previous** iteration. The spike flag is
33+ set on ``engine.state.unhealthy_spike`` and can be checked at the start
34+ of the next iteration's train step.
35+
36+ .. warning::
37+ This handler requires that gradients are **not zeroed** at the end of
38+ your train step. If you call ``optimizer.zero_grad()`` at the end of
39+ ``train_step``, all gradients will be gone by the time this handler
40+ runs and the norm will always be 0. Instead, call
41+ ``optimizer.zero_grad()`` at the **start** of your train step, after
42+ checking the spike flag.
43+
44+ .. warning::
45+ The ``unhealthy_spike`` flag on ``engine.state`` reflects gradients
46+ from the **previous** iteration, not the current one. Design your
47+ train step accordingly.
48+
49+ .. warning::
50+ Call ``attach`` only once per engine instance. Calling it multiple
51+ times on the same engine will raise a ``RuntimeError`` to prevent
52+ doubled stat updates and corrupted running statistics.
53+
54+ .. note::
55+ If multiple handlers are registered to ``Events.ITERATION_STARTED``,
56+ execution order depends on registration order. Register
57+ ``GradMonitor`` before any handler that reads ``unhealthy_spike``
58+ to ensure the flag is fresh when read.
59+
60+ Args:
61+ model: The model whose gradient norms will be monitored.
62+ k: Multiplier for standard deviation when using the default spike
63+ detector. Ignored if ``spike_detector`` is provided. Default: 3.0.
64+ scaler: Optional ``torch.cuda.amp.GradScaler`` instance for AMP
65+ workflows. When provided, the raw gradient norm is divided by
66+ the current scale factor to recover the true unscaled norm.
67+ spike_detector: Optional callable with signature
68+ ``(mean, m2, count, norm, k) -> bool``. If provided, replaces
69+ the default ``mean + k * std`` rule entirely. Use this to
70+ implement custom thresholding logic.
71+
72+ Example:
73+ .. code-block:: python
74+
75+ import torch
76+ from ignite.engine import Engine, Events
77+ from ignite.handlers import GradMonitor
78+
79+ model = torch.nn.Linear(10, 1)
80+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
81+
82+ def train_step(engine, batch):
83+ # Check the spike flag set by GradMonitor from the previous
84+ # iteration. On iteration 1 this will always be False
85+ # (not enough history yet).
86+ if engine.state.unhealthy_spike:
87+ # Discard this batch entirely.
88+ # Do NOT zero grads here so GradMonitor can still
89+ # read them on the next iteration.
90+ return {"loss": None, "skipped": True}
91+
92+ optimizer.zero_grad() # zero at START, not end
93+ x, y = batch
94+ loss = ((model(x) - y) ** 2).mean()
95+ loss.backward()
96+ optimizer.step()
97+ # Do NOT call optimizer.zero_grad() here.
98+ return {"loss": loss.item(), "skipped": False}
99+
100+ trainer = Engine(train_step)
101+
102+ # Attach GradMonitor BEFORE any handler that reads unhealthy_spike
103+ # to guarantee correct ordering on Events.ITERATION_STARTED.
104+ monitor = GradMonitor(model, k=3.0)
105+ monitor.attach(trainer)
106+
107+ # This handler runs after GradMonitor because it is registered after.
108+ @trainer.on(Events.ITERATION_STARTED)
109+ def log_spike(engine):
110+ if engine.state.unhealthy_spike:
111+ print(f"Spike at iteration {engine.state.iteration}!")
112+
113+ Example with AMP:
114+ .. code-block:: python
115+
116+ scaler = torch.cuda.amp.GradScaler()
117+ monitor = GradMonitor(model, k=3.0, scaler=scaler)
118+ monitor.attach(trainer)
119+
120+ Example with a custom spike detector:
121+ .. code-block:: python
122+
123+ def my_detector(mean, m2, count, norm, k):
124+ # Flag only if norm exceeds an absolute limit of 100.
125+ return norm > 100.0
126+
127+ monitor = GradMonitor(model, spike_detector=my_detector)
128+ monitor.attach(trainer)
129+
130+ .. versionadded:: 0.6.0
131+ """
132+
133+ def __init__ (
134+ self ,
135+ model : torch .nn .Module ,
136+ k : float = 3.0 ,
137+ scaler : Optional [Any ] = None ,
138+ spike_detector : Optional [Callable ] = None ,
139+ ):
140+ if not isinstance (model , torch .nn .Module ):
141+ raise TypeError (
142+ f"model must be a torch.nn.Module, got { type (model )} "
143+ )
144+ if not isinstance (k , (int , float )):
145+ raise TypeError (
146+ f"k must be a numeric value, got { type (k )} "
147+ )
148+ if k <= 0 :
149+ raise ValueError (
150+ f"k must be a positive number, got { k } "
151+ )
152+
153+ self .model = model
154+ self .k = k
155+ self .scaler = scaler
156+ self .spike_detector = spike_detector if spike_detector is not None else _default_spike_detector
157+ self ._device : Optional [torch .device ] = None
158+ self ._attached : bool = False
159+ self .count = 0
160+ self .mean = 0.0
161+ self .m2 = 0.0
162+
163+ def _get_device (self ) -> torch .device :
164+ if self ._device is None :
165+ try :
166+ self ._device = next (self .model .parameters ()).device
167+ except StopIteration :
168+ self ._device = torch .device ("cpu" )
169+ return self ._device
170+
171+ def _compute_grad_norm (self ) -> float :
172+ """
173+ Computes the global L2 gradient norm across all model parameters.
174+
175+ Handles:
176+ - AMP: divides by GradScaler scale if a scaler was provided.
177+ - DDP: sums squared norms across all distributed processes.
178+
179+ Returns:
180+ The L2 gradient norm as a Python float.
181+ """
182+ device = self ._get_device ()
183+ total_norm_sq = torch .tensor (0.0 , device = device )
184+
185+ for p in self .model .parameters ():
186+ if p .grad is not None :
187+ total_norm_sq += p .grad .pow (2 ).sum ()
188+
189+ # DDP support: sum squared norms across all processes so every.
190+ # node sees the same global norm.
191+ try :
192+ from ignite .distributed import get_world_size , all_reduce
193+ if get_world_size () > 1 :
194+ total_norm_sq = all_reduce (total_norm_sq )
195+ except ImportError :
196+ pass
197+
198+ total_norm : float = torch .sqrt (total_norm_sq ).item ()
199+
200+ # AMP support: divide by the scaler's scale factor to recover.
201+ # the true unscaled gradient norm.
202+ if self .scaler is not None :
203+ scale = self .scaler .get_scale ()
204+ if scale > 0 :
205+ total_norm /= scale
206+
207+ return total_norm
208+
209+ def _update_stats (self , norm : float ) -> None :
210+ """
211+ Update running mean and variance using Welford's online algorithm.
212+ O(1) memory, numerically stable.
213+
214+ Args:
215+ norm: The gradient norm from the current iteration.
216+ """
217+ self .count += 1
218+ delta = norm - self .mean
219+ self .mean += delta / self .count
220+ delta2 = norm - self .mean
221+ self .m2 += delta * delta2
222+
223+ def __call__ (self , engine : Engine ) -> None :
224+ """
225+ Called at every ``Events.ITERATION_STARTED``.
226+ Computes the gradient norm, evaluates the spike detector, sets
227+ ``engine.state.unhealthy_spike``, then updates running statistics.
228+
229+ Args:
230+ engine: The Ignite training engine.
231+ """
232+ norm = self ._compute_grad_norm ()
233+
234+ # Evaluate spike BEFORE updating stats so the current norm is compared against history from previous iterations only.
235+ engine .state .unhealthy_spike = self .spike_detector (
236+ self .mean , self .m2 , self .count , norm , self .k
237+ )
238+
239+ self ._update_stats (norm )
240+
241+ def attach (self , engine : Engine ) -> "GradMonitor" :
242+ """
243+ Attach this handler to an engine.
244+
245+ Initialises ``engine.state.unhealthy_spike = False`` at the start
246+ of each run so the flag is always safe to read inside the train step
247+ from the very first iteration.
248+
249+ Raises:
250+ RuntimeError: If this handler has already been attached to an engine.
251+
252+ Args:
253+ engine: The Ignite training engine.
254+
255+ Returns:
256+ self, to allow fluent chaining:
257+ ``GradMonitor(model).attach(trainer)``.
258+ """
259+ if self ._attached :
260+ raise RuntimeError (
261+ "GradMonitor is already attached to an engine. "
262+ "Create a new GradMonitor instance to attach to a different engine."
263+ )
264+ self ._attached = True
265+
266+ @engine .on (Events .STARTED )
267+ def _init_flag (e : Engine ) -> None :
268+ e .state .unhealthy_spike = False
269+
270+ if hasattr (engine , "state_dict_user_keys" ):
271+ engine .state_dict_user_keys .append ("unhealthy_spike" )
272+
273+ engine .add_event_handler (Events .ITERATION_STARTED , self )
274+ return self
0 commit comments