4646 NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False
4747
4848
def gather_from_labels(data: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Pick out the value at each label index along the last dimension of ``data``.

    Commonly used to extract the log-probability of specific token IDs from a
    vocabulary distribution.

    Args:
        data: Tensor of shape (..., vocab_size) to gather from.
        label: Index tensor of shape (...,) with values in [0, vocab_size).

    Returns:
        torch.Tensor: Gathered values of shape (...,), matching ``label``.

    Example:
        >>> logits = torch.randn(2, 3, 100)            # [batch, seq, vocab]
        >>> labels = torch.randint(0, 100, (2, 3))     # [batch, seq]
        >>> gathered = gather_from_labels(logits, labels)  # [batch, seq]
    """
    # Add a trailing singleton dim so the index shape lines up for gather,
    # then drop it again after the lookup.
    index = label.unsqueeze(-1)
    gathered = torch.gather(data, dim=-1, index=index)
    return gathered.squeeze(-1)
6270
@@ -92,30 +100,89 @@ def logprobs_from_logits(logits, labels, inplace_backward=True):
92100 return output
93101
94102
def logprobs_from_logits_flash_attn(
    logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True
) -> torch.Tensor:
    """Compute log-probabilities via Flash Attention's fused cross-entropy kernel.

    Args:
        logits: Model output logits of shape (batch_size, vocab_size).
        labels: Target token indices of shape (batch_size,).
        inplace_backward: If True, the backward pass is performed in-place for
            memory efficiency.

    Returns:
        torch.Tensor: Log-probabilities of the target labels, shape (batch_size,).

    Raises:
        AssertionError: If flash-attn < 2.4.3, where ``cross_entropy_loss``
            returns a bare tensor instead of a (losses, z_losses) tuple.
    """
    result = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)
    # flash-attn >= 2.4.3 returns a tuple (losses, z_losses); reject older
    # versions explicitly instead of silently mis-reading the output.
    assert isinstance(result, tuple), (
        "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
    )
    losses = result[0]
    # Cross-entropy loss is the negative log-likelihood, so negate.
    return -losses
101127
102128
def logprobs_from_logits_torch_npu(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute log-probabilities via torch_npu's native cross-entropy (Ascend NPU).

    Args:
        logits: Model output logits of shape (..., vocab_size).
        labels: Target token indices of shape (...,).

    Returns:
        torch.Tensor: Log-probabilities of the target labels, same shape as labels.
    """
    leading_shape = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    # The NPU kernel expects 2-D (tokens, vocab) input, so flatten all
    # leading dimensions and restore them afterwards.
    flat_logits = logits.reshape(-1, vocab_size)
    flat_labels = labels.reshape(-1)
    # npu_cross_entropy_loss returns a 4-tuple; only the first element
    # (per-token loss) is needed here.
    loss = torch_npu.npu_cross_entropy_loss(flat_logits, flat_labels, reduction="none")[0]
    # Cross-entropy is negative log-likelihood, so negate to get log-probs.
    return -loss.view(*leading_shape)
108146
109147
def logprobs_from_logits_naive(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute log-probabilities with a plain log-softmax + gather.

    Straightforward reference implementation: less memory-efficient than the
    specialized kernels, but works on any device.

    Args:
        logits: Model output logits of shape (..., vocab_size).
        labels: Target token indices of shape (...,).

    Returns:
        torch.Tensor: Log-probabilities of the target labels, same shape as labels.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Select the log-prob of each target token along the vocab dimension.
    return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)
114164
115165
116- def logprobs_from_logits_v2 (logits : torch .FloatTensor , labels ):
117- """
118- A memory efficient implementation of logprobs_from_logits
166+ def logprobs_from_logits_v2 (logits : torch .FloatTensor , labels : torch .Tensor ) -> torch .Tensor :
167+ """Memory-efficient log-probability computation using row-wise processing.
168+
169+ Computes log-probabilities by processing one row at a time to reduce peak
170+ memory consumption. Uses logsumexp for float32/float64, falls back to
171+ log_softmax for bfloat16 due to numerical stability concerns.
172+
173+ The mathematical identity used is: log_softmax(x_i) = x_i - logsumexp(x)
174+
175+ Args:
176+ logits: Model output logits of shape (batch_size, seq_len, vocab_size)
177+ or (batch_size, vocab_size).
178+ labels: Target token indices matching logits shape without vocab dimension.
179+
180+ Returns:
181+ torch.Tensor: Log-probabilities for target labels.
182+
183+ Note:
184+ This implementation trades compute for memory by iterating over batch
185+ dimension, making it suitable for large vocabulary sizes.
119186 """
120187 if logits .dtype in [torch .float32 , torch .float64 ]:
121188 logits_labels = torch .gather (logits , dim = - 1 , index = labels .unsqueeze (- 1 )).squeeze (- 1 )
@@ -133,24 +200,62 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
133200 return logprobs_labels
134201
135202
def clip_by_value(
    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
) -> torch.Tensor:
    """Clip tensor values to a range defined by tensor-valued bounds.

    Extension of ``torch.clamp`` supporting tensor min/max bounds instead of
    only scalars.

    Args:
        x: Input tensor to clip.
        tensor_min: Minimum bound tensor (broadcastable to x).
        tensor_max: Maximum bound tensor (broadcastable to x).

    Returns:
        torch.Tensor: Tensor with values clipped to [tensor_min, tensor_max].

    See Also:
        https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
    """
    # Cap from above first, then floor from below — same order as the
    # elementwise max(min(x, hi), lo) formulation.
    capped = torch.minimum(x, tensor_max)
    return torch.maximum(capped, tensor_min)
143224
144225
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Calculate Shannon entropy from unnormalized logits.

    Uses the numerically stable identity:
        H(p) = logsumexp(logits) - sum(softmax(logits) * logits)

    Args:
        logits: Unnormalized log-probabilities of shape (..., vocab_size).

    Returns:
        torch.Tensor: Entropy values of shape (...,), one per distribution.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # E_p[logit] term of the stable entropy formula.
    expected_logit = (probs * logits).sum(dim=-1)
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition - expected_logit
150241
151242
152- def entropy_from_logits_with_chunking (logits : torch .Tensor , chunk_size : int = 2048 ):
153- """Memory-efficient entropy calculation with chunking."""
243+ def entropy_from_logits_with_chunking (logits : torch .Tensor , chunk_size : int = 2048 ) -> torch .Tensor :
244+ """Memory-efficient entropy calculation using chunked processing.
245+
246+ Computes entropy by processing the batch in chunks to reduce peak memory
247+ usage. Useful for large batch sizes or when memory is constrained.
248+
249+ Args:
250+ logits: Unnormalized log-probabilities of shape (batch_size, vocab_size).
251+ chunk_size: Number of samples to process at once. Defaults to 2048.
252+
253+ Returns:
254+ torch.Tensor: Entropy values with shape (batch_size,).
255+
256+ Note:
257+ Converts chunks to float32 for numerical stability during computation.
258+ """
154259 entropy = torch .zeros (logits .shape [0 ], device = logits .device )
155260 for i in range (0 , logits .shape [0 ], chunk_size ):
156261 logits_chunk = logits [i : i + chunk_size ].float ()
@@ -160,8 +265,23 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
160265 return entropy
161266
162267
163- def masked_sum (values , mask , axis = None ):
164- """Compute mean of tensor with a masked values."""
268+ def masked_sum (
269+ values : torch .Tensor , mask : torch .Tensor , axis : int | tuple [int , ...] | None = None
270+ ) -> torch .Tensor :
271+ """Compute sum of tensor values where mask is True.
272+
273+ NaN values outside the mask are replaced with zeros to prevent
274+ contaminating the sum.
275+
276+ Args:
277+ values: Input tensor containing values to sum.
278+ mask: Boolean or numeric mask tensor (same shape as values).
279+ Non-zero values indicate elements to include.
280+ axis: Dimension(s) along which to sum. None sums all elements.
281+
282+ Returns:
283+ torch.Tensor: Sum of masked values, reduced along specified axis.
284+ """
165285 # If NaNs exist out of mask, replace NaNs in values with a value that
166286 # won't affect the sum (e.g., 0 for masked regions)
167287 valid_values = torch .where (mask .bool (), values , 0.0 )
@@ -246,35 +366,71 @@ def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2,
246366 return (eos_mask .cumsum (dim = 1 ) - eos_mask ).eq (0 ).to (dtype )
247367
248368
def compute_grad_norm(model: nn.Module) -> float:
    """Compute the squared L2 norm of all gradients in a model.

    Sums the squared entries of every populated ``.grad`` tensor across all
    parameters; parameters without gradients are skipped.

    Args:
        model: PyTorch model whose gradients have been computed.

    Returns:
        float: Sum of squared gradient values. Note this is the *squared*
            norm — take the square root for the actual L2 norm.
    """
    grads = (p.grad for p in model.parameters() if p.grad is not None)
    return sum(float(g.detach().pow(2).sum().item()) for g in grads)
255390
256391
def broadcast_dict_tensor(
    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
) -> None:
    """Broadcast all tensors in a dictionary from the source rank to all ranks.

    Iterates the tensors in deterministic key order and broadcasts each one
    from ``src`` to every other rank in ``group``.

    Args:
        tensors: Dictionary or TensorDict containing tensors to broadcast.
        src: Source rank to broadcast from.
        group: Process group for the broadcast operation.

    Note:
        Tensors are broadcast one at a time; this could be optimized into a
        single broadcast of packed tensors.
    """
    # Fix: plain dicts have no ``sorted_keys`` attribute (only TensorDict
    # does), yet the signature accepts both. Sort explicitly for dicts so
    # every rank iterates the keys in the same deterministic order.
    if isinstance(tensors, TensorDict):
        keys = tensors.sorted_keys
    else:
        keys = sorted(tensors.keys())
    for key in keys:
        torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)
264411
265412
266- def allgather_dict_tensors (tensors : dict [str , torch .Tensor ] | TensorDict , size , group , dim = 0 ):
267- """
268- TODO: optimize this.
269- - We can use async ops
270- - We can use only one allgather
413+ def allgather_dict_tensors (
414+ tensors : dict [str , torch .Tensor ] | TensorDict , size : int , group , dim : int = 0
415+ ) -> dict [str , torch .Tensor ] | TensorDict :
416+ """Gather tensors from all ranks and concatenate them.
417+
418+ Performs all_gather on each tensor in the dictionary and concatenates
419+ the results along the specified dimension.
420+
271421 Args:
272- tensors:
273- size:
274- group:
422+ tensors: Dictionary or TensorDict containing tensors to gather.
423+ size: Number of ranks in the process group.
424+ group: Process group for the all_gather operation.
425+ dim: Dimension along which to concatenate gathered tensors. Defaults to 0.
275426
276427 Returns:
428+ Dictionary or TensorDict (matching input type) with gathered and
429+ concatenated tensors. Each tensor's size along `dim` is multiplied by `size`.
277430
431+ Note:
432+ This implementation gathers tensors one at a time synchronously.
433+ Could be optimized using async ops or packed all_gather.
278434 """
279435 if isinstance (tensors , TensorDict ):
280436 is_tensor_dict = True
0 commit comments