Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,9 @@ def eval_val_sliding(
seq_len = args.train_seq_len
total_tokens = val_tokens.numel() - 1

- # Build windows; skip any too short to score a full stride
+ # Build windows; include final partial window if it has at least 1 token
window_starts = [ws for ws in range(0, total_tokens, stride)
- if min(ws + seq_len, total_tokens) - ws >= stride]
+ if min(ws + seq_len, total_tokens) - ws >= 1]
total_windows = len(window_starts)

# Distribute across ranks
Expand Down Expand Up @@ -899,7 +899,7 @@ def eval_val_sliding(

for i, ws in enumerate(batch_ws):
wlen = wlens[i]
- s = 0 if ws == 0 else wlen - stride
+ s = 0 if ws == 0 else max(wlen - stride, 0)
scored_nll = nll[i, s:wlen].to(torch.float64)
loss_sum += scored_nll.sum()
token_count += float(wlen - s)
Expand Down