Merged
Changes from 7 commits
149 changes: 108 additions & 41 deletions src/transformers/model_debugging_utils.py
@@ -17,7 +17,8 @@
import json
import os
import re
from contextlib import contextmanager
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Optional

from transformers.utils.import_utils import export
@@ -87,20 +88,51 @@ def _serialize_io(value):

if hasattr(value, "_local_tensor"):
# DTensor-like handling, just use local tensor attribute
return {
torch.set_printoptions(sci_mode=True)
Collaborator
we could have max line width increased as well!

Contributor Author
it's done in the repr_to_list method, will unify this ;)

val_repr = _repr_to_list(value)
out = {
"shape": repr(value._local_tensor.shape),
"dtype": repr(value._local_tensor.dtype),
"value": _sanitize_repr_for_diff(repr(value)),
"value": val_repr,
}
if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}:
value = value._local_tensor.clone()
out.update({
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
})
return out

if isinstance(value, torch.Tensor):
# standard PyTorch Tensor
# return also the shape of such
return {"shape": repr(value.shape), "dtype": repr(value.dtype), "value": _sanitize_repr_for_diff(repr(value))}
torch.set_printoptions(sci_mode=True)
val_repr = _repr_to_list(value)
out = {
"shape": repr(value.shape),
"dtype": repr(value.dtype),
"value": val_repr,
}
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
out.update({
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
})
return out

# fallback for everything else (bool, int, float, None, or custom class)
return _sanitize_repr_for_diff(repr(value))

def _repr_to_list(val):
return _sanitize_repr_for_diff(repr(val)).splitlines()

def _repr_to_list(value: torch.Tensor):
torch.set_printoptions(sci_mode=True, linewidth=120)
with StringIO() as buf, redirect_stdout(buf):
print(value) # to redirected stdout to avoid line splits
raw = buf.getvalue()
return _sanitize_repr_for_diff(raw).splitlines()
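
For readers skimming the diff: the new _repr_to_list renders each tensor with print() into an in-memory buffer so the repr keeps its natural line breaks, then stores the lines as a JSON-friendly list. Below is a minimal self-contained sketch of that trick; the sanitizer regex is a simplified stand-in for illustration, not the exact _sanitize_repr_for_diff used in this file.

# --- illustrative sketch, not part of the diff ---
import re
from contextlib import redirect_stdout
from io import StringIO

import torch


def sanitize_for_diff(raw: str) -> str:
    # Strip device info so traces from different machines diff cleanly (simplified stand-in).
    return re.sub(r", device='[^']+'", "", raw)


def repr_to_list(value: torch.Tensor) -> list[str]:
    # sci_mode keeps magnitudes comparable across runs; a wider linewidth avoids awkward wraps.
    torch.set_printoptions(sci_mode=True, linewidth=120)
    with StringIO() as buf, redirect_stdout(buf):
        print(value)  # print() goes to the in-memory buffer, not the terminal
        raw = buf.getvalue()
    return sanitize_for_diff(raw).splitlines()


print(repr_to_list(torch.randn(2, 6)))  # each printed line becomes one list entry in the JSON
# --- end of sketch ---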

def prune_outputs_if_children(node):
# if there are children, remove this node's "outputs"
@@ -111,27 +143,81 @@ def prune_outputs_if_children(node):
prune_outputs_if_children(child)


LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number

def is_layer_block(node):
match = LAYER_SUFFIX_RE.match(node.get("module_path", ""))
if not match or not node.get("children"):
return False
number = match.group(2)
return any(f".{number}." in child.get("module_path", "") for child in node["children"])


def prune_intermediate_layers(node):
if not node.get("children"):
return

layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)]

if len(layer_blocks) > 2:
to_remove = [i for i, _ in layer_blocks[1:-1]]
node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove]

for child in node["children"]:
prune_intermediate_layers(child)
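
A quick sanity check of the two helpers above on a toy call tree (module paths are made up for illustration): a child counts as a layer block when its module_path ends in ".<number>" and at least one grandchild path contains that index, and only the first and last such blocks survive pruning.

# --- illustrative sketch, not part of the diff ---
toy_tree = {
    "module_path": "DummyModel",
    "children": [
        {
            "module_path": f"DummyModel.model.layers.{i}",
            "children": [{"module_path": f"DummyModel.model.layers.{i}.self_attn", "children": []}],
        }
        for i in range(4)
    ],
}

prune_intermediate_layers(toy_tree)
print([child["module_path"] for child in toy_tree["children"]])
# -> ['DummyModel.model.layers.0', 'DummyModel.model.layers.3']
# --- end of sketch ---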


def log_model_debug_trace(debug_path, model):
if debug_path:
try:
os.makedirs(debug_path, exist_ok=False)
output_path = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree.json")
os.makedirs(debug_path, exist_ok=True)
base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
except Exception as e:
raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
else:
output_path = model._debugger_module_dump_name + "_debug_tree.json"
logger.info(f"Writing model trace at {output_path}")
with open(output_path, "w") as outfile:
prune_outputs_if_children(model._call_tree)
json.dump(model._call_tree, outfile, indent=2)
base = model._debugger_module_dump_name + "_debug_tree"

logger.info(f"Writing model trace at {base}.json")
full_path = base + "_FULL_TENSORS.json"
summary_path = base + "_SUMMARY.json"

prune_outputs_if_children(model._call_tree)

with open(full_path, "w") as f:
json.dump(model._call_tree, f, indent=2)

# summary-only version for readability - this traverses the tree again (TODO: optimize?)
def strip_values(node):
def clean(val):
if isinstance(val, dict):
val.pop("value", None)
for v in val.values():
clean(v)
elif isinstance(val, list):
for item in val:
clean(item)

clean(node.get("inputs", {}))
clean(node.get("outputs", {}))

for child in node.get("children", []):
strip_values(child)


tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy
strip_values(tree_copy)

with open(summary_path, "w") as f:
json.dump(tree_copy, f, indent=2)
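
Net effect: the same call tree is written twice, once with full tensor values (*_FULL_TENSORS.json) and once with only shapes, dtypes, and statistics (*_SUMMARY.json). A rough sketch of one serialized output before and after stripping, with made-up field contents:

# --- illustrative sketch, not part of the diff ---
import json

serialized = {
    "shape": "torch.Size([1, 4])",
    "dtype": "torch.float32",
    "value": ["tensor([[ 1.2000e-01, -3.4000e-01,  5.6000e-01,  7.8000e-01]])"],
    "mean": "tensor(2.8000e-01)",
    "std": "tensor(4.9600e-01)",
    "min": "tensor(-3.4000e-01)",
    "max": "tensor(7.8000e-01)",
}

summary = json.loads(json.dumps(serialized))  # same cheap deep copy as above
summary.pop("value", None)                    # SUMMARY keeps shape/dtype/stats only
print(json.dumps(summary, indent=2))
# --- end of sketch ---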




def _attach_debugger_logic(model, class_name, debug_path: str):
# Prepare data structures on the model object
model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
model._debugger_model_call_stack = []
model._debugger_module_dump_name = class_name # used for final JSON filename

def wrap_forward(module, full_path):
orig_forward = module.forward

@@ -147,7 +233,7 @@ def wrapped_forward(*inps, **kws):
"children": [],
}
model._debugger_model_call_stack.append(node)
with torch.inference_mode():
with torch.no_grad():
out = orig_forward(*inps, **kws)

if _is_rank_zero():
@@ -188,7 +274,6 @@ def top_wrapped_forward(*inps, **kws):
model._debugger_model_call_stack.append(top_node)

out = real_top_forward(*inps, **kws)

if _is_rank_zero() and model._debugger_model_call_stack:
top_node["outputs"] = _serialize_io(out)
finished = model._debugger_model_call_stack.pop()
@@ -198,40 +283,22 @@ def top_wrapped_forward(*inps, **kws):
# prune empty stuff for visibility
[model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]]

# prune layers that are not 0 or last
prune_intermediate_layers(model._call_tree)
# Write final JSON trace here
log_model_debug_trace(debug_path=debug_path, model=model)
return out

model.forward = top_wrapped_forward

# Final hook for writing JSON on forward-end
def final_hook(_, inputs, outputs):
if _is_rank_zero() and model._debugger_model_call_stack:
finished = model._debugger_model_call_stack.pop()
model._call_tree["inputs"] = finished["inputs"]
model._call_tree["outputs"] = finished["outputs"]
model._call_tree["children"] = finished["children"]

if _is_rank_zero():
log_model_debug_trace(debug_path=debug_path, model=model)

model.register_forward_hook(final_hook)
# Optionally also register the hook on a couple of submodules with specific names (there should be just one).
# These cover models whose main computation lives in a named submodule rather than the top-level forward,
# but we should not need to recurse through them.
possible_model_calls = ["language_model", "model"]
for model_call in possible_model_calls:
this_model_call = getattr(model, model_call, None)
if this_model_call and isinstance(this_model_call, (nn.Module, PreTrainedModel)):
this_model_call.register_forward_hook(final_hook)
break # exit the loop after finding one (unsure, but should be just one call.)
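
The core mechanism of this utility, stripped of the transformers-specific bookkeeping: each submodule's forward is monkey-patched to push a node onto a call stack, run the original forward under torch.no_grad(), and attach the finished node to its parent. A minimal standalone sketch of that pattern (module and attribute names are made up for illustration):

# --- illustrative sketch, not part of the diff ---
import torch
from torch import nn


def attach_tracer(model: nn.Module):
    model._tree = {"module_path": model.__class__.__name__, "children": []}
    model._stack = [model._tree]

    def wrap(module: nn.Module, path: str):
        orig_forward = module.forward

        def wrapped(*args, **kwargs):
            node = {"module_path": path, "children": []}
            model._stack[-1]["children"].append(node)  # attach to the current parent
            model._stack.append(node)
            with torch.no_grad():
                out = orig_forward(*args, **kwargs)
            model._stack.pop()
            return out

        module.forward = wrapped  # instance attribute shadows the class method

    for name, submodule in model.named_modules():
        if name:  # skip the root module itself
            wrap(submodule, f"{model.__class__.__name__}.{name}")


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
attach_tracer(net)
net(torch.randn(1, 4))
print([child["module_path"] for child in net._tree["children"]])
# -> ['Sequential.0', 'Sequential.1', 'Sequential.2']
# --- end of sketch ---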


@export(backends=("torch",))
def model_addition_debugger(cls):
"""
# Model addition debugger - a model adder tracer
This decorator is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output in a nested JSON file.
To note, this decorator enforces `torch.inference_mode()`.
To note, this decorator enforces `torch.no_grad()`.
## Usage

Add the decorator to your model class:
@@ -289,7 +356,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None):
# Model addition debugger - context manager for model adders
This context manager is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output in a nested JSON file.
To note, this context manager enforces `torch.inference_mode()`.
To note, this context manager enforces `torch.no_grad()`.

## Usage
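
Since the usage example is collapsed in this view, here is a minimal sketch of how the context manager is meant to be invoked; the checkpoint name and prompt are placeholders, not taken from this PR.

# --- illustrative sketch, not part of the diff ---
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.model_debugging_utils import model_addition_debugger_context

model_id = "gpt2"  # placeholder: any small checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt")

# Writes <ClassName>_debug_tree_FULL_TENSORS.json and <ClassName>_debug_tree_SUMMARY.json under debug_path.
with model_addition_debugger_context(model, debug_path="debug_traces"):
    outputs = model.forward(**inputs)
# --- end of sketch ---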
