Commit d6d546a

Authored by yurekami and Claude
[doc] feat: improve docstrings in torch_functional.py (verl-project#1345) (verl-project#4730)
## Summary

This PR improves the documentation quality of `verl/utils/torch_functional.py` by adding comprehensive Google-style docstrings to core utility functions.

### Functions Improved

| Function | Change |
|----------|--------|
| `gather_from_labels` | Complete docstring with example usage |
| `logprobs_from_logits_flash_attn` | Added docstring explaining Flash Attention usage |
| `logprobs_from_logits_torch_npu` | Added docstring for NPU implementation |
| `logprobs_from_logits_naive` | Added docstring for standard implementation |
| `logprobs_from_logits_v2` | Enhanced docstring explaining memory efficiency |
| `clip_by_value` | Complete docstring with See Also reference |
| `entropy_from_logits` | Added docstring with mathematical formula |
| `entropy_from_logits_with_chunking` | Added docstring explaining chunking strategy |
| `masked_sum` | Complete docstring with type hints |
| `compute_grad_norm` | Added docstring clarifying squared norm return |
| `broadcast_dict_tensor` | Added docstring with optimization note |
| `allgather_dict_tensors` | Complete docstring with return type |

### Improvements Include

- Proper type hints on function signatures
- Detailed `Args` and `Returns` sections
- `Note` sections for important implementation details
- Mathematical formulas where applicable (e.g., entropy calculation)
- Usage examples for complex functions

## Test plan

- [x] Python syntax verification passed
- [ ] CI tests should pass (documentation-only changes)

Contributes to verl-project#1345

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 7d0e8dd commit d6d546a

File tree

1 file changed: +190 additions, -34 deletions

verl/utils/torch_functional.py

Lines changed: 190 additions & 34 deletions
@@ -46,17 +46,25 @@
     NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False


-def gather_from_labels(data, label):
-    """Gather the label from data. The value in label should be [0, vocab_size)
+def gather_from_labels(data: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
+    """Gather values from data tensor at positions specified by label indices.
+
+    Selects elements from the last dimension of `data` based on indices in `label`.
+    Commonly used to extract log-probabilities for specific token IDs from a
+    vocabulary distribution.

     Args:
-        data: (..., vocab_size)
-        label (torch.IntTensor) : (...,)
+        data: Input tensor of shape (..., vocab_size) containing values to gather from.
+        label: Index tensor of shape (...,) with values in range [0, vocab_size).

     Returns:
+        torch.Tensor: Gathered values with shape (...,), same as label shape.

+    Example:
+        >>> logits = torch.randn(2, 3, 100)  # [batch, seq, vocab]
+        >>> labels = torch.randint(0, 100, (2, 3))  # [batch, seq]
+        >>> gathered = gather_from_labels(logits, labels)  # [batch, seq]
     """
-
     output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)
     return output
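As a quick illustration of the gather-along-last-dimension semantics the new docstring describes, here is a stdlib-only sketch using plain lists in place of tensors (`gather_last_dim` is a hypothetical local helper, not the verl function):

```python
# Pure-Python sketch of gather_from_labels semantics: for each row of
# data, pick the element at the index given by label (illustrative only).
def gather_last_dim(data, label):
    """data: [batch][vocab] values, label: [batch] indices -> [batch] values."""
    return [row[idx] for row, idx in zip(data, label)]

logp = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]  # (batch=2, vocab=3)
labels = [1, 0]                             # one token id per row
print(gather_last_dim(logp, labels))        # [0.7, 0.5]
```

The torch implementation does the same thing in one call via `torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)`.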

@@ -92,30 +100,89 @@ def logprobs_from_logits(logits, labels, inplace_backward=True):
     return output


-def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):
+def logprobs_from_logits_flash_attn(
+    logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True
+) -> torch.Tensor:
+    """Compute log-probabilities using Flash Attention's optimized cross-entropy.
+
+    Uses the Flash Attention library's Triton-based cross-entropy implementation
+    for efficient computation on NVIDIA GPUs.
+
+    Args:
+        logits: Model output logits of shape (batch_size, vocab_size).
+        labels: Target token indices of shape (batch_size,).
+        inplace_backward: If True, perform backward pass in-place for memory efficiency.
+
+    Returns:
+        torch.Tensor: Log-probabilities for target labels, shape (batch_size,).
+
+    Raises:
+        AssertionError: If flash-attn version < 2.4.3 (different return format).
+    """
     output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)
     assert isinstance(output, tuple), (
         "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
     )
     return -output[0]


-def logprobs_from_logits_torch_npu(logits, labels):
+def logprobs_from_logits_torch_npu(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Compute log-probabilities using Ascend NPU's optimized cross-entropy.
+
+    Uses torch_npu's native cross-entropy implementation for efficient
+    computation on Huawei Ascend NPU devices.
+
+    Args:
+        logits: Model output logits of shape (..., vocab_size).
+        labels: Target token indices of shape (...,).
+
+    Returns:
+        torch.Tensor: Log-probabilities for target labels, same shape as labels.
+    """
     batch_dim = logits.shape[:-1]
     logits = logits.reshape(-1, logits.shape[-1])
     loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none")
     return -loss.view(*batch_dim)


-def logprobs_from_logits_naive(logits, labels):
+def logprobs_from_logits_naive(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Compute log-probabilities using standard log-softmax approach.
+
+    Simple implementation using PyTorch's log_softmax followed by gathering.
+    Less memory-efficient than specialized implementations but works on all devices.
+
+    Args:
+        logits: Model output logits of shape (..., vocab_size).
+        labels: Target token indices of shape (...,).
+
+    Returns:
+        torch.Tensor: Log-probabilities for target labels, same shape as labels.
+    """
     logp = F.log_softmax(logits, dim=-1)
     logpy = gather_from_labels(logp, labels)
     return logpy


-def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
-    """
-    A memory efficient implementation of logprobs_from_logits
+def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) -> torch.Tensor:
+    """Memory-efficient log-probability computation using row-wise processing.
+
+    Computes log-probabilities by processing one row at a time to reduce peak
+    memory consumption. Uses logsumexp for float32/float64, falls back to
+    log_softmax for bfloat16 due to numerical stability concerns.
+
+    The mathematical identity used is: log_softmax(x_i) = x_i - logsumexp(x)
+
+    Args:
+        logits: Model output logits of shape (batch_size, seq_len, vocab_size)
+            or (batch_size, vocab_size).
+        labels: Target token indices matching logits shape without vocab dimension.
+
+    Returns:
+        torch.Tensor: Log-probabilities for target labels.
+
+    Note:
+        This implementation trades compute for memory by iterating over batch
+        dimension, making it suitable for large vocabulary sizes.
     """
     if logits.dtype in [torch.float32, torch.float64]:
         logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
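The identity cited in the `logprobs_from_logits_v2` docstring can be sanity-checked with the standard library alone; the `logsumexp` below is a local stand-in with the usual max-subtraction trick, not the torch function:

```python
import math

def logsumexp(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Hypothetical logits for a 3-token vocabulary.
logits = [2.0, 0.5, -1.0]
lse = logsumexp(logits)

# Identity: log_softmax(x_i) = x_i - logsumexp(x)
log_softmax = [x - lse for x in logits]

# Cross-check against taking log of the softmax probabilities directly.
probs = [math.exp(x - lse) for x in logits]
direct = [math.log(p) for p in probs]
assert all(abs(a - b) < 1e-9 for a, b in zip(log_softmax, direct))
```

This is exactly why the v2 implementation can gather `x_i` first and subtract a per-row logsumexp, instead of materializing the full log-softmax tensor.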
@@ -133,24 +200,62 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
     return logprobs_labels


-def clip_by_value(x, tensor_min, tensor_max):
-    """
-    Tensor extenstion to torch.clamp
-    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
+def clip_by_value(
+    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
+) -> torch.Tensor:
+    """Clip tensor values to a range defined by tensor bounds.
+
+    Extension of torch.clamp that supports tensor-valued min/max bounds
+    instead of only scalar bounds.
+
+    Args:
+        x: Input tensor to clip.
+        tensor_min: Minimum bound tensor (broadcastable to x).
+        tensor_max: Maximum bound tensor (broadcastable to x).
+
+    Returns:
+        torch.Tensor: Clipped tensor with values in [tensor_min, tensor_max].
+
+    See Also:
+        https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
     """
     clipped = torch.max(torch.min(x, tensor_max), tensor_min)
     return clipped


-def entropy_from_logits(logits: torch.Tensor):
-    """Calculate entropy from logits."""
+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """Calculate Shannon entropy from unnormalized logits.
+
+    Computes H(p) = -sum(p * log(p)) using the numerically stable formula:
+    entropy = logsumexp(logits) - sum(softmax(logits) * logits)
+
+    Args:
+        logits: Unnormalized log-probabilities of shape (..., vocab_size).
+
+    Returns:
+        torch.Tensor: Entropy values with shape (...,), one per distribution.
+    """
     pd = torch.nn.functional.softmax(logits, dim=-1)
     entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
     return entropy


-def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048):
-    """Memory-efficient entropy calculation with chunking."""
+def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor:
+    """Memory-efficient entropy calculation using chunked processing.
+
+    Computes entropy by processing the batch in chunks to reduce peak memory
+    usage. Useful for large batch sizes or when memory is constrained.
+
+    Args:
+        logits: Unnormalized log-probabilities of shape (batch_size, vocab_size).
+        chunk_size: Number of samples to process at once. Defaults to 2048.
+
+    Returns:
+        torch.Tensor: Entropy values with shape (batch_size,).
+
+    Note:
+        Converts chunks to float32 for numerical stability during computation.
+    """
     entropy = torch.zeros(logits.shape[0], device=logits.device)
     for i in range(0, logits.shape[0], chunk_size):
         logits_chunk = logits[i : i + chunk_size].float()
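The stable entropy formula in the `entropy_from_logits` docstring can be verified against the textbook definition with a stdlib-only check (the `logsumexp` helper and example logits are hypothetical, not from verl):

```python
import math

def logsumexp(xs):
    # Max-subtraction trick for numerical stability.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

logits = [1.0, 2.0, 3.0]  # hypothetical 3-way distribution
lse = logsumexp(logits)
probs = [math.exp(x - lse) for x in logits]

# Stable form used by the function: logsumexp(x) - sum(p * x)
stable = lse - sum(p * x for p, x in zip(probs, logits))

# Textbook definition: H(p) = -sum(p * log(p))
naive = -sum(p * math.log(p) for p in probs)

assert abs(stable - naive) < 1e-9
```

The stable form never takes `log(p)` directly, so it avoids `log(0)` issues when some probability underflows to zero.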
@@ -160,8 +265,23 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
     return entropy


-def masked_sum(values, mask, axis=None):
-    """Compute mean of tensor with a masked values."""
+def masked_sum(
+    values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
+) -> torch.Tensor:
+    """Compute sum of tensor values where mask is True.
+
+    NaN values outside the mask are replaced with zeros to prevent
+    contaminating the sum.
+
+    Args:
+        values: Input tensor containing values to sum.
+        mask: Boolean or numeric mask tensor (same shape as values).
+            Non-zero values indicate elements to include.
+        axis: Dimension(s) along which to sum. None sums all elements.
+
+    Returns:
+        torch.Tensor: Sum of masked values, reduced along specified axis.
+    """
     # If NaNs exist out of mask, replace NaNs in values with a value that
     # won't affect the sum (e.g., 0 for masked regions)
     valid_values = torch.where(mask.bool(), values, 0.0)
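The NaN-handling point in the `masked_sum` docstring is easy to see in a stdlib-only sketch that mirrors the `torch.where` step (this local `masked_sum` is an illustrative stand-in, not the verl function):

```python
def masked_sum(values, mask):
    # Replace anything outside the mask with 0.0 first, so a NaN in a
    # masked-out position cannot contaminate the total (mirrors
    # torch.where(mask.bool(), values, 0.0) followed by a sum).
    return sum(v if m else 0.0 for v, m in zip(values, mask))

vals = [1.0, float("nan"), 3.0]
mask = [True, False, True]
print(masked_sum(vals, mask))  # 4.0
```

Without the replacement step, `1.0 + nan + 3.0` would propagate NaN through the whole reduction.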
@@ -246,35 +366,71 @@ def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2,
     return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)


-def compute_grad_norm(model: nn.Module):
+def compute_grad_norm(model: nn.Module) -> float:
+    """Compute the squared L2 norm of all gradients in a model.
+
+    Sums the squared values of all gradient tensors across all parameters.
+    Useful for monitoring gradient magnitudes during training.
+
+    Args:
+        model: PyTorch model with computed gradients.
+
+    Returns:
+        float: Sum of squared gradient values (not the square root).
+
+    Note:
+        Returns the squared norm, not the norm itself. To get the actual
+        L2 norm, take the square root of the returned value.
+    """
     total_grad_square = 0
     for param in model.parameters():
         if param.grad is not None:
             total_grad_square += torch.sum(torch.square(param.grad.detach())).item()
     return total_grad_square


-def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group):
-    """
-    TODO: optimize this. Technically, we only need one broadcast
-    """
+def broadcast_dict_tensor(
+    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
+) -> None:
+    """Broadcast all tensors in a dictionary from source rank to all ranks.
+
+    Iterates over all tensors in the dictionary and broadcasts each one
+    from the source rank to all other ranks in the process group.
+
+    Args:
+        tensors: Dictionary or TensorDict containing tensors to broadcast.
+        src: Source rank from which to broadcast.
+        group: Process group for the broadcast operation.

+    Note:
+        This implementation broadcasts tensors one at a time. Could be optimized
+        to use a single broadcast with packed tensors.
+    """
     for key in tensors.sorted_keys:
         torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)


-def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0):
-    """
-    TODO: optimize this.
-    - We can use async ops
-    - We can use only one allgather
+def allgather_dict_tensors(
+    tensors: dict[str, torch.Tensor] | TensorDict, size: int, group, dim: int = 0
+) -> dict[str, torch.Tensor] | TensorDict:
+    """Gather tensors from all ranks and concatenate them.
+
+    Performs all_gather on each tensor in the dictionary and concatenates
+    the results along the specified dimension.
+
     Args:
-        tensors:
-        size:
-        group:
+        tensors: Dictionary or TensorDict containing tensors to gather.
+        size: Number of ranks in the process group.
+        group: Process group for the all_gather operation.
+        dim: Dimension along which to concatenate gathered tensors. Defaults to 0.

     Returns:
+        Dictionary or TensorDict (matching input type) with gathered and
+        concatenated tensors. Each tensor's size along `dim` is multiplied by `size`.

+    Note:
+        This implementation gathers tensors one at a time synchronously.
+        Could be optimized using async ops or packed all_gather.
     """
     if isinstance(tensors, TensorDict):
         is_tensor_dict = True
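The `compute_grad_norm` docstring's caveat, that the function returns the sum of squared gradient entries rather than the L2 norm itself, is worth making concrete. A stdlib-only sketch with hypothetical flattened gradient values:

```python
import math

# Stand-in for flattened gradient values across all parameters.
grads = [3.0, 4.0]

# What compute_grad_norm returns: the SUM of squares, not the norm.
squared = sum(g * g for g in grads)

# The actual L2 norm requires a final square root, as the Note says.
l2 = math.sqrt(squared)

print(squared, l2)  # 25.0 5.0
```

So callers comparing against a gradient-clipping threshold should remember to take `sqrt` first (or square the threshold).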

0 commit comments
