Adds use_repr to model_addition_debugger_context (#37984)

* Adds use_repr to model_addition_debugger_context
* Updating docs for use_repr option

parent 38f9c5b15b
commit 9eb0a37c9e
--- a/docs/source/en/internal/model_debugging_utils.md
+++ b/docs/source/en/internal/model_debugging_utils.md
@@ -16,7 +16,8 @@ rendered properly in your Markdown viewer.

 # Model debugging toolboxes

-This page lists all the debugging and model adding tools used by the library, as well as the utility functions it provides for it.
+This page lists all the debugging and model adding tools used by the library, as well as the utility functions it
+provides for it.

 Most of those are only useful if you are adding new models in the library.
@@ -26,13 +27,14 @@ Most of those are only useful if you are adding new models in the library.

 ### Model addition debugger - context manager for model adders

-This context manager is a power user tool intended for model adders.
-It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
-To note, this context manager enforces `torch.no_grad()`.
+This context manager is a power user tool intended for model adders. It tracks all forward calls within a model forward
+and logs a slice of each input and output on a nested JSON. To note, this context manager enforces `torch.no_grad()`.

 ### Rationale

-Because when porting models to transformers, even from python to python, model adders often have to do a lot of manual operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some time.
+When porting models to transformers, even from python to python, model adders often have to do a lot of manual
+operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some
+time.

 ### Usage
@@ -62,10 +64,10 @@ inputs = processor(text=prompt, images=random_image, return_tensors="pt")

 # call forward method (not .generate!)
 with model_addition_debugger_context(
-  model,
-  debug_path="optional_path_to_your_directory",
-  do_prune_layers=False # This will output ALL the layers of a model.
-):
+    model,
+    debug_path="optional_path_to_your_directory",
+    do_prune_layers=False # This will output ALL the layers of a model.
+):
     output = model.forward(**inputs)

 ```
@@ -73,8 +75,8 @@ with model_addition_debugger_context(

 ### Reading results

-The debugger generates two files from the forward call, both with the same base name,
-but ending either with `_SUMMARY.json` or with `_FULL_TENSORS.json`.
+The debugger generates two files from the forward call, both with the same base name, but ending either with
+`_SUMMARY.json` or with `_FULL_TENSORS.json`.

 The first one will contain a summary of each module's _input_ and _output_ tensor values and shapes.
@@ -142,8 +144,8 @@ The first one will contain a summary of each module's _input_ and _output_ tenso
 { ... and so on
 ```

-The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful
-for comparing two files.
+The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful for comparing two files.

 ```json
   "pixel_values": {
     "shape": "torch.Size([1, 5, 576, 588])",
@@ -196,9 +198,38 @@ for comparing two files.
   },
 ```

+#### Saving tensors to disk
+
+Some model adders may benefit from logging full tensor values to disk to support, for example, numerical analysis
+across implementations.
+
+Set `use_repr=False` to write tensors to disk using [SafeTensors](https://huggingface.co/docs/safetensors/en/index).
+
+```python
+with model_addition_debugger_context(
+    model,
+    debug_path="optional_path_to_your_directory",
+    do_prune_layers=False,
+    use_repr=False,  # Defaults to True
+):
+    output = model.forward(**inputs)
+```
+
+When using `use_repr=False`, tensors are written to the same disk location as the `_SUMMARY.json` and
+`_FULL_TENSORS.json` files. The `value` property of entries in the `_FULL_TENSORS.json` file will contain a relative
+path reference to the associated `.safetensors` file. Each tensor is written to its own file as the `data` property of
+the state dictionary. File names are constructed using the `module_path` as a prefix with a few possible postfixes that
+are built recursively.
+
+* Module inputs are denoted by the `_inputs` postfix and outputs by `_outputs`.
+* `list` and `tuple` instances, such as `args` or function return values, will be postfixed with `_{index}`.
+* `dict` instances will be postfixed with `_{key}`.
+
 ### Comparing between implementations

-Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong.
+Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See
+below: there are slight differences between these two implementations' key projection layer. Inputs are mostly
+identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong.

 
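Since each dump written with `use_repr=False` is a standard SafeTensors file whose single `data` key holds the tensor (see the `save_file` call later in this diff), two implementations' dumps can also be compared numerically rather than visually. A minimal sketch, with hypothetical file names that follow the `module_path`-prefixed naming scheme described above:

```python
# Minimal sketch: numerically compare one dumped tensor from two traced runs.
# The directory and file names below are hypothetical examples of the
# `{module_path}_outputs_{index}.safetensors` naming described above.
import torch
from safetensors.torch import load_file

ref = load_file("reference_run/model.layers.0.self_attn.k_proj_outputs_0.safetensors")["data"]
new = load_file("ported_run/model.layers.0.self_attn.k_proj_outputs_0.safetensors")["data"]

print("max abs diff:", (ref - new).abs().max().item())
print("allclose:", torch.allclose(ref, new, atol=1e-5))
```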
@@ -206,8 +237,13 @@ Once the forward passes of two models have been traced by the debugger, one can

 ### Limitations and scope

-This feature will only work for torch-based models, and would require more work and case-by-case approach for say `jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be traced once instead of reran N times with breakpoints.
+This feature will only work for torch-based models, and would require more work and a case-by-case approach for, say,
+`jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but the trace
+will probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can
+be traced once instead of rerun N times with breakpoints.

-If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. Else, only the first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N layers.
+If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be output to `json`. Otherwise, only
+the first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N
+layers.

 [[autodoc]] model_addition_debugger_context
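The JSON comparison itself needs no special tooling: the `_SUMMARY.json` files are plain text, and memory addresses are replaced with stable placeholders (per the `_serialize_io` docstring below), so a unified diff pinpoints diverging modules. A minimal sketch, with placeholder paths:

```python
# Minimal sketch: line-diff two _SUMMARY.json traces (paths are placeholders).
import difflib

with open("reference_run_SUMMARY.json") as f:
    ref_lines = f.readlines()
with open("ported_run_SUMMARY.json") as f:
    new_lines = f.readlines()

print("".join(difflib.unified_diff(ref_lines, new_lines, fromfile="reference", tofile="ported")))
```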
--- a/src/transformers/model_debugging_utils.py
+++ b/src/transformers/model_debugging_utils.py
@@ -21,6 +21,8 @@ from contextlib import contextmanager, redirect_stdout
 from io import StringIO
 from typing import Optional

+from safetensors.torch import save_file
+
 from transformers.utils.import_utils import requires

 from .utils import is_torch_available
@@ -65,64 +67,94 @@ def _dtensor_repr(x):
     return "DTensor(non-rank0)"


-def _serialize_io(value):
+def _serialize_tensor_like_io(
+    value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
+):
+    """
+    Converts Tensors and DTensors to a JSON-serializable dictionary representation.
+
+    Args:
+        value: Any Python object, often including torch Tensors, lists, dicts, etc.
+        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
+        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as
+            the `value` property in the associated FULL_TENSORS.json file, or to store the full tensor in a separate
+            SafeTensors file and store the relative path to that file in the `value` property.
+        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
+            tensor value if `use_repr=False`.
+
+    Returns:
+        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
+    """
+    torch.set_printoptions(sci_mode=True)
+
+    if use_repr:
+        value_out = _repr_to_list(value)
+    elif path_to_value:
+        if not path_to_value.endswith(".safetensors"):
+            path_to_value += ".safetensors"
+
+        filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
+        save_file({"data": value.contiguous().detach().cpu()}, filepath)
+        value_out = f"./{path_to_value}"
+    else:
+        raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")
+
+    out = {
+        "shape": repr(value.shape),
+        "dtype": repr(value.dtype),
+        "value": value_out,
+    }
+    if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
+        out.update(
+            {
+                "mean": _sanitize_repr_for_diff(repr(value.mean())),
+                "std": _sanitize_repr_for_diff(repr(value.std())),
+                "min": _sanitize_repr_for_diff(repr(value.min())),
+                "max": _sanitize_repr_for_diff(repr(value.max())),
+            }
+        )
+    return out
+
+
+def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
     """
     Recursively build a JSON-serializable Python structure from `value`.
-    Tensors and DTensors become sanitized repr strings.
+    Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
+    relative paths are recorded in the returned Python structure.
     Lists/tuples/dicts are recursed into.
     All memory addresses are replaced with a stable placeholder.

     Args:
         value: Any Python object, often including torch Tensors, lists, dicts, etc.
+        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
+        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as
+            the `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate
+            SafeTensors files and store the relative path to that file in the `value` property.
+        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
+            tensor value if `use_repr=False`.

     Returns:
         A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
     """
     if isinstance(value, (list, tuple)):
-        return [_serialize_io(v) for v in value]
+        return [
+            _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
+            for i, v in enumerate(value)
+        ]

     if isinstance(value, dict):
-        return {k: _serialize_io(v) for k, v in value.items()}
+        return {
+            k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
+            for k, v in value.items()
+        }

     if hasattr(value, "_local_tensor"):
         # DTensor-like handling, just use local tensor attribute
-        torch.set_printoptions(sci_mode=True)
-        val_repr = _repr_to_list(value)
-        out = {
-            "shape": repr(value._local_tensor.shape),
-            "dtype": repr(value._local_tensor.dtype),
-            "value": val_repr,
-        }
-        if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}:
-            value = value._local_tensor.clone()
-            out.update(
-                {
-                    "mean": _sanitize_repr_for_diff(repr(value.mean())),
-                    "std": _sanitize_repr_for_diff(repr(value.std())),
-                    "min": _sanitize_repr_for_diff(repr(value.min())),
-                    "max": _sanitize_repr_for_diff(repr(value.max())),
-                }
-            )
-        return out
+        return _serialize_tensor_like_io(
+            value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
+        )

     if isinstance(value, torch.Tensor):
-        torch.set_printoptions(sci_mode=True)
-        val_repr = _repr_to_list(value)
-        out = {
-            "shape": repr(value.shape),
-            "dtype": repr(value.dtype),
-            "value": val_repr,
-        }
-        if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
-            out.update(
-                {
-                    "mean": _sanitize_repr_for_diff(repr(value.mean())),
-                    "std": _sanitize_repr_for_diff(repr(value.std())),
-                    "min": _sanitize_repr_for_diff(repr(value.min())),
-                    "max": _sanitize_repr_for_diff(repr(value.max())),
-                }
-            )
-        return out
+        return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)

     return _sanitize_repr_for_diff(repr(value))
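A quick way to see the recursive `path_to_value` naming from the hunk above in action is to call the private helper directly. A minimal sketch, assuming the post-PR module layout (private API, subject to change):

```python
# Minimal sketch: observe the recursive file naming of _serialize_io with use_repr=False.
import torch
from transformers.model_debugging_utils import _serialize_io

# Hypothetical nested input: a dict holding a tensor and a list of tensors.
inputs = {"pixel_values": torch.randn(1, 4), "ids": [torch.ones(2), torch.zeros(2)]}

# Per the logic above, one .safetensors file is dumped per leaf tensor:
#   demo_inputs_pixel_values.safetensors  (dict key postfix)
#   demo_inputs_ids_0.safetensors         (list index postfix)
#   demo_inputs_ids_1.safetensors
tree = _serialize_io(inputs, debug_path=".", use_repr=False, path_to_value="demo_inputs")
print(tree["pixel_values"]["value"])  # -> "./demo_inputs_pixel_values.safetensors"
```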
@@ -199,7 +231,7 @@ def log_model_debug_trace(debug_path, model):
             os.makedirs(debug_path, exist_ok=True)
             base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
         except Exception as e:
-            raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
+            raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
     else:
         base = model._debugger_module_dump_name + "_debug_tree"
@@ -240,6 +272,7 @@ def _attach_debugger_logic(
     model,
     debug_path: Optional[str] = ".",
     do_prune_layers: Optional[bool] = True,
+    use_repr: bool = True,
 ):
     """
     Attaches a debugging wrapper to every module in the model.
@@ -250,6 +283,9 @@ def _attach_debugger_logic(
         model (`PreTrainedModel`, `nn.Module`): Model to wrap.
         debug_path (`str`): Optional directory to dump debug JSON files.
         do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
+        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as
+            the `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate
+            SafeTensors files and store the relative path to that file in the `value` property.
     """
     class_name = model.__class__.__name__
@@ -258,6 +294,12 @@ def _attach_debugger_logic(
     model._debugger_model_call_stack = []
     model._debugger_module_dump_name = class_name  # used for final JSON filename

+    if debug_path:
+        try:
+            os.makedirs(debug_path, exist_ok=True)
+        except Exception as e:
+            raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
+
     def wrap_forward(module, full_path):
         orig_forward = module.forward
@@ -268,7 +310,12 @@ def _attach_debugger_logic(
                 dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
                 node = {
                     "module_path": full_path,
-                    "inputs": _serialize_io(dict_inputs),
+                    "inputs": _serialize_io(
+                        dict_inputs,
+                        debug_path=debug_path,
+                        use_repr=use_repr,
+                        path_to_value=f"{full_path}_inputs",
+                    ),
                     "outputs": None,
                     "children": [],
                 }
@@ -280,7 +327,12 @@ def _attach_debugger_logic(
                 if sum(1 for _ in module.named_children()) > 0:
                     node["outputs"] = None
                 else:
-                    node["outputs"] = _serialize_io(out)
+                    node["outputs"] = _serialize_io(
+                        out,
+                        debug_path=debug_path,
+                        use_repr=use_repr,
+                        path_to_value=f"{full_path}_outputs",
+                    )

                 finished = model._debugger_model_call_stack.pop()
                 # prune empty vertices here as well (mostly empty children nodes)
@@ -307,7 +359,12 @@ def _attach_debugger_logic(
         if _is_rank_zero():
             top_node = {
                 "module_path": f"{class_name} (top-level)",
-                "inputs": _serialize_io({"args": inps, "kwargs": kws}),
+                "inputs": _serialize_io(
+                    {"args": inps, "kwargs": kws},
+                    debug_path=debug_path,
+                    use_repr=use_repr,
+                    path_to_value=f"{class_name}_inputs",
+                ),
                 "outputs": None,
                 "children": [],
             }
@@ -315,7 +372,12 @@ def _attach_debugger_logic(

         out = real_top_forward(*inps, **kws)
         if _is_rank_zero() and model._debugger_model_call_stack:
-            top_node["outputs"] = _serialize_io(out)
+            top_node["outputs"] = _serialize_io(
+                out,
+                debug_path=debug_path,
+                use_repr=use_repr,
+                path_to_value=f"{class_name}_outputs",
+            )
             finished = model._debugger_model_call_stack.pop()
             model._call_tree["inputs"] = finished["inputs"]
             model._call_tree["outputs"] = finished["outputs"]
@@ -335,11 +397,21 @@ def _attach_debugger_logic(

 @requires(backends=("torch",))
 @contextmanager
-def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True):
+def model_addition_debugger_context(
+    model,
+    debug_path: Optional[str] = None,
+    do_prune_layers: Optional[bool] = True,
+    use_repr: Optional[bool] = True,
+):
     """
     # Model addition debugger - context manager for model adders
     This context manager is a power user tool intended for model adders.
-    It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
+
+    It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
+    If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
+    strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will
+    provide a relative path to that file.

     To note, this context manager enforces `torch.no_grad()`.

     ## Usage
@@ -348,10 +420,10 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_

     ```python
     import torch

     from PIL import Image
     import requests
-    from transformers import LlavaProcessor, LlavaForConditionalGeneration
-    from transformers.model_debugging_utils import model_addition_debugger_context
+    from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context

     torch.random.manual_seed(673)

     # load pretrained model and processor
@@ -376,7 +448,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_
     """
     orig_forwards = {m: m.forward for _, m in model.named_modules()}
     orig_forwards[model] = model.forward
-    _attach_debugger_logic(model, debug_path, do_prune_layers)
+    _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
     try:
         yield model
     finally:
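Downstream of the context manager, the relative paths recorded under `value` can be resolved back to tensors. A minimal sketch, assuming `use_repr=False` was used and that the dump is named `{ClassName}_debug_tree_FULL_TENSORS.json` (the base name comes from `log_model_debug_trace`; the exact key path below is hypothetical):

```python
# Minimal sketch: resolve a relative safetensors path recorded in _FULL_TENSORS.json.
import json
import os

from safetensors.torch import load_file

debug_path = "optional_path_to_your_directory"  # same directory passed to the context manager
with open(os.path.join(debug_path, "LlavaForConditionalGeneration_debug_tree_FULL_TENSORS.json")) as f:
    tree = json.load(f)

# Top-level inputs were serialized from {"args": ..., "kwargs": ...}; with use_repr=False each leaf
# entry's "value" is a relative path such as "./LlavaForConditionalGeneration_inputs_kwargs_input_ids.safetensors".
rel_path = tree["inputs"]["kwargs"]["input_ids"]["value"]
tensor = load_file(os.path.join(debug_path, rel_path))["data"]
print(tensor.shape, tensor.dtype)
```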