Adds use_repr to model_addition_debugger_context (#37984)

* Adds use_repr to model_addition_debugger_context

* Updating docs for use_repr option
This commit is contained in:
Ryan Mullins 2025-05-23 05:35:13 -04:00 committed by GitHub
parent 38f9c5b15b
commit 9eb0a37c9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 175 additions and 67 deletions

View File

@ -16,7 +16,8 @@ rendered properly in your Markdown viewer.
# Model debugging toolboxes
This page lists all the debugging and model adding tools used by the library, as well as the utility functions it provides for it.
This page lists all the debugging and model adding tools used by the library, as well as the utility functions it
provides for it.
Most of those are only useful if you are adding new models in the library.
@ -26,13 +27,14 @@ Most of those are only useful if you are adding new models in the library.
### Model addition debugger - context manager for model adders
This context manager is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
To note, this context manager enforces `torch.no_grad()`.
This context manager is a power user tool intended for model adders. It tracks all forward calls within a model forward
and logs a slice of each input and output on a nested JSON. To note, this context manager enforces `torch.no_grad()`.
### Rationale
Because when porting models to transformers, even from python to python, model adders often have to do a lot of manual operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some time.
When porting models to transformers, even from python to python, model adders often have to do a lot of manual
operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some
time.
### Usage
@ -62,10 +64,10 @@ inputs = processor(text=prompt, images=random_image, return_tensors="pt")
# call forward method (not .generate!)
with model_addition_debugger_context(
model,
debug_path="optional_path_to_your_directory",
do_prune_layers=False # This will output ALL the layers of a model.
):
model,
debug_path="optional_path_to_your_directory",
do_prune_layers=False # This will output ALL the layers of a model.
):
output = model.forward(**inputs)
```
@ -73,8 +75,8 @@ with model_addition_debugger_context(
### Reading results
The debugger generates two files from the forward call, both with the same base name,
but ending either with `_SUMMARY.json` or with `_FULL_TENSORS.json`.
The debugger generates two files from the forward call, both with the same base name, but ending either with
`_SUMMARY.json` or with `_FULL_TENSORS.json`.
The first one will contain a summary of each module's _input_ and _output_ tensor values and shapes.
@ -142,8 +144,8 @@ The first one will contain a summary of each module's _input_ and _output_ tenso
{ ... and so on
```
The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful
for comparing two files.
The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful for comparing two files.
```json
"pixel_values": {
"shape": "torch.Size([1, 5, 576, 588])",
@ -196,9 +198,38 @@ for comparing two files.
},
```
#### Saving tensors to disk
Some model adders may benefit from logging full tensor values to disk to support, for example, numerical analysis
across implementations.
Set `use_repr=False` to write tensors to disk using [SafeTensors](https://huggingface.co/docs/safetensors/en/index).
```python
with model_addition_debugger_context(
model,
debug_path="optional_path_to_your_directory",
do_prune_layers=False,
use_repr=False, # Defaults to True
):
output = model.forward(**inputs)
```
When using `use_repr=False`, tensors are written to the same disk location as the `_SUMMARY.json` and
`_FULL_TENSORS.json` files. The `value` property of entries in the `_FULL_TENSORS.json` file will contain a relative
path reference to the associated `.safetensors` file. Each tensor is written to its own file as the `data` property of
the state dictionary. File names are constructed using the `module_path` as a prefix with a few possible postfixes that
are built recursively.
* Module inputs are denoted with the `_inputs` and outputs by `_outputs`.
* `list` and `tuple` instances, such as `args` or function return values, will be postfixed with `_{index}`.
* `dict` instances will be postfixed with `_{key}`.
### Comparing between implementations
Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong.
Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See
below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly
identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong.
![download-icon](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/files_difference_debugging.png)
@ -206,8 +237,13 @@ Once the forward passes of two models have been traced by the debugger, one can
### Limitations and scope
This feature will only work for torch-based models, and would require more work and case-by-case approach for say `jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be traced once instead of reran N times with breakpoints.
This feature will only work for torch-based models, and would require more work and case-by-case approach for say
`jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will
probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be
traced once instead of reran N times with breakpoints.
If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. Else, only the first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N layers.
If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. Else, only the
first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N
layers.
[[autodoc]] model_addition_debugger_context

View File

@ -21,6 +21,8 @@ from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Optional
from safetensors.torch import save_file
from transformers.utils.import_utils import requires
from .utils import is_torch_available
@ -65,64 +67,94 @@ def _dtensor_repr(x):
return "DTensor(non-rank0)"
def _serialize_io(value):
def _serialize_tensor_like_io(
value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
):
"""
Converts Tensors and DTensors to a JSON-serializable dictionary representation.
Args:
value: Any Python object, often including torch Tensors, lists, dicts, etc.
debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the
`value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate
SafeTensors file and store the relative path to that file in the `value` property in the dictionary.
path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
tensor value if `use_repr=False`.
Returns:
A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
"""
torch.set_printoptions(sci_mode=True)
if use_repr:
value_out = _repr_to_list(value)
elif path_to_value:
if not path_to_value.endswith(".safetensors"):
path_to_value += ".safetensors"
filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
save_file({"data": value.contiguous().detach().cpu()}, filepath)
value_out = f"./{path_to_value}"
else:
raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")
out = {
"shape": repr(value.shape),
"dtype": repr(value.dtype),
"value": value_out,
}
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
out.update(
{
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
}
)
return out
def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
"""
Recursively build a JSON-serializable Python structure from `value`.
Tensors and DTensors become sanitized repr strings.
Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
relative paths are recorded in the returned Python structure.
Lists/tuples/dicts are recursed into.
All memory addresses are replaced with a stable placeholder.
Args:
value: Any Python object, often including torch Tensors, lists, dicts, etc.
debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
`value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
files and store the relative path to that file in the `value` property.
path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
tensor value if `use_repr=False`.
Returns:
A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
"""
if isinstance(value, (list, tuple)):
return [_serialize_io(v) for v in value]
return [
_serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
for i, v in enumerate(value)
]
if isinstance(value, dict):
return {k: _serialize_io(v) for k, v in value.items()}
return {
k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
for k, v in value.items()
}
if hasattr(value, "_local_tensor"):
# DTensor-like handling, just use local tensor attribute
torch.set_printoptions(sci_mode=True)
val_repr = _repr_to_list(value)
out = {
"shape": repr(value._local_tensor.shape),
"dtype": repr(value._local_tensor.dtype),
"value": val_repr,
}
if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}:
value = value._local_tensor.clone()
out.update(
{
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
}
)
return out
return _serialize_tensor_like_io(
value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
)
if isinstance(value, torch.Tensor):
torch.set_printoptions(sci_mode=True)
val_repr = _repr_to_list(value)
out = {
"shape": repr(value.shape),
"dtype": repr(value.dtype),
"value": val_repr,
}
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
out.update(
{
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
}
)
return out
return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)
return _sanitize_repr_for_diff(repr(value))
@ -199,7 +231,7 @@ def log_model_debug_trace(debug_path, model):
os.makedirs(debug_path, exist_ok=True)
base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
except Exception as e:
raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
else:
base = model._debugger_module_dump_name + "_debug_tree"
@ -240,6 +272,7 @@ def _attach_debugger_logic(
model,
debug_path: Optional[str] = ".",
do_prune_layers: Optional[bool] = True,
use_repr: bool = True,
):
"""
Attaches a debugging wrapper to every module in the model.
@ -250,6 +283,9 @@ def _attach_debugger_logic(
model (`PreTrainedModel`, `nn.Module`): Model to wrap.
debug_path (`str`): Optional directory to dump debug JSON files.
do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
`value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
files and store the relative path to that file in the `value` property.
"""
class_name = model.__class__.__name__
@ -258,6 +294,12 @@ def _attach_debugger_logic(
model._debugger_model_call_stack = []
model._debugger_module_dump_name = class_name # used for final JSON filename
if debug_path:
try:
os.makedirs(debug_path, exist_ok=True)
except Exception as e:
raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
def wrap_forward(module, full_path):
orig_forward = module.forward
@ -268,7 +310,12 @@ def _attach_debugger_logic(
dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
node = {
"module_path": full_path,
"inputs": _serialize_io(dict_inputs),
"inputs": _serialize_io(
dict_inputs,
debug_path=debug_path,
use_repr=use_repr,
path_to_value=f"{full_path}_inputs",
),
"outputs": None,
"children": [],
}
@ -280,7 +327,12 @@ def _attach_debugger_logic(
if sum(1 for _ in module.named_children()) > 0:
node["outputs"] = None
else:
node["outputs"] = _serialize_io(out)
node["outputs"] = _serialize_io(
out,
debug_path=debug_path,
use_repr=use_repr,
path_to_value=f"{full_path}_outputs",
)
finished = model._debugger_model_call_stack.pop()
# prune empty vertices here as well (mostly empty children nodes)
@ -307,7 +359,12 @@ def _attach_debugger_logic(
if _is_rank_zero():
top_node = {
"module_path": f"{class_name} (top-level)",
"inputs": _serialize_io({"args": inps, "kwargs": kws}),
"inputs": _serialize_io(
{"args": inps, "kwargs": kws},
debug_path=debug_path,
use_repr=use_repr,
path_to_value=f"{class_name}_inputs",
),
"outputs": None,
"children": [],
}
@ -315,7 +372,12 @@ def _attach_debugger_logic(
out = real_top_forward(*inps, **kws)
if _is_rank_zero() and model._debugger_model_call_stack:
top_node["outputs"] = _serialize_io(out)
top_node["outputs"] = _serialize_io(
out,
debug_path=debug_path,
use_repr=use_repr,
path_to_value=f"{class_name}_outputs",
)
finished = model._debugger_model_call_stack.pop()
model._call_tree["inputs"] = finished["inputs"]
model._call_tree["outputs"] = finished["outputs"]
@ -335,11 +397,21 @@ def _attach_debugger_logic(
@requires(backends=("torch",))
@contextmanager
def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True):
def model_addition_debugger_context(
model,
debug_path: Optional[str] = None,
do_prune_layers: Optional[bool] = True,
use_repr: Optional[bool] = True,
):
"""
# Model addition debugger - context manager for model adders
This context manager is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
strings. If `use_repr=False`, the full tensors will be stored in spearate SafeTensors files and the JSON file will
provide a relative path to that file.
To note, this context manager enforces `torch.no_grad()`.
## Usage
@ -348,10 +420,10 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_
```python
import torch
from PIL import Image
import requests
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from transformers.model_debugging_utils import model_addition_debugger_context
from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
torch.random.manual_seed(673)
# load pretrained model and processor
@ -376,7 +448,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_
"""
orig_forwards = {m: m.forward for _, m in model.named_modules()}
orig_forwards[model] = model.forward
_attach_debugger_logic(model, debug_path, do_prune_layers)
_attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
try:
yield model
finally: