Add model visual debugger (#36798)

* draft of model tracer visualiser

* add context manager in addition to decorator

* add debug utils to init

* move model debugging utils to dedicated file

* add documentation

* protect some imports

* format

* move and protect imports

* format

* doc: improve errors in case of broken dummy imports.

* format

* use automatic torch backend

* update doc

* fix backend

* (TEMP) move to dummies while backend wait

* update documentation

* doc
Pablo Montalvo 2025-03-20 17:37:29 +01:00 committed by GitHub
parent 6515c25953
commit 1d3f35f30a
5 changed files with 432 additions and 2 deletions

View File

@@ -0,0 +1,71 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Model debugging toolboxes
This page lists the debugging and model-addition tools used by the library, as well as the utility functions it provides for them.
Most of these are only useful if you are adding new models to the library.
## Model addition debuggers
### Model addition debugger - context manager for model adders
This context manager is a power user tool intended for model adders.
It tracks all forward calls within a model's forward pass and logs a slice of each input and output in a nested JSON file.
Note that this context manager enforces `torch.inference_mode()`.
### Rationale
When porting models to `transformers`, even from Python to Python, model adders often have to perform many manual operations: saving and loading tensors, comparing dtypes, and so on. This small tool can hopefully shave off some of that time.
### Usage
Add this context manager as follows to debug a model:
```python
import torch
from PIL import Image
from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context

torch.random.manual_seed(673)

# load pretrained model and processor
model_id = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)

# create random image input
random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

# prompt
prompt = "<image>Describe this image."

# process inputs
inputs = processor(text=prompt, images=random_image, return_tensors="pt")

# call forward method (not .generate!)
with model_addition_debugger_context(model, "optional_path_to_your_output_directory"):
    output = model.forward(**inputs)
```
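The debug trace is written as a nested JSON tree of forward calls. As a rough, hypothetical sketch of its shape (the `module_path`, `inputs`, `outputs` and `children` keys come from the tool itself; the module names, shapes and tensor slices below are made up for illustration):

```python
# Hypothetical, truncated sketch of the generated *_debug_tree.json, shown as a Python dict.
# Only the key names are guaranteed by the tool; every value here is illustrative.
{
    "module_path": "LlavaForConditionalGeneration",
    "inputs": {
        "kwargs": {
            "input_ids": {"shape": "torch.Size([1, 602])", "dtype": "torch.int64", "value": "tensor([[...]])"},
            "pixel_values": {"shape": "torch.Size([1, 3, 336, 336])", "dtype": "torch.float32", "value": "tensor([[...]])"},
        }
    },
    "children": [
        {
            "module_path": "LlavaForConditionalGeneration.vision_tower",
            "inputs": {"args": [{"shape": "torch.Size([1, 3, 336, 336])", "dtype": "torch.float32", "value": "tensor([[...]])"}]},
            "children": ["..."],
        },
        "...",
    ],
}
```

Nodes that have children keep only their inputs; outputs are logged at the leaf modules to keep the file readable.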
[[autodoc]] utils.model_addition_debugger
[[autodoc]] utils.model_addition_debugger_context

View File

@@ -1376,6 +1376,10 @@ except OptionalDependencyNotAvailable:
    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
    _import_structure["model_debugging_utils"] = [
        "model_addition_debugger",
        "model_addition_debugger_context",
    ]
    _import_structure["activations"] = []
    _import_structure["cache_utils"] = [
        "Cache",
@@ -6605,6 +6609,7 @@ if TYPE_CHECKING:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *
    else:
        # Debugging
        from .cache_utils import (
            Cache,
            CacheConfig,
@@ -6690,6 +6695,10 @@ if TYPE_CHECKING:
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )
        from .model_debugging_utils import (
            model_addition_debugger,
            model_addition_debugger_context,
        )
        from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
        from .modeling_utils import PreTrainedModel
        from .models.albert import (

View File

@@ -0,0 +1,329 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import json
import os
import re
from contextlib import contextmanager

from transformers.utils.import_utils import export

from .utils import is_torch_available


if is_torch_available():
    import torch
    import torch.distributed.tensor
    from torch import nn

    from .modeling_utils import PreTrainedModel

from .utils import logging


logger = logging.get_logger(__name__)
# Note to code inspectors: this toolbox is intended for people who add models to `transformers`.

_torch_distributed_available = torch.distributed.is_available()


def _is_rank_zero():
    """Return True if rank=0 or we aren't running distributed."""
    if not (_torch_distributed_available and torch.distributed.is_initialized()):
        return True
    return torch.distributed.get_rank() == 0


MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+")


def _sanitize_repr_for_diff(x_str: str) -> str:
    """
    Replace memory addresses in an object's repr with a stable placeholder
    so that beautiful JSON diffs won't be ruined by ephemeral addresses.
    """
    return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str)


def _dtensor_repr(x):
    """Return a stable string representation for a DTensor-like object."""
    if _is_rank_zero():
        return f"DTensor (rank0) -> {repr(x._local_tensor)}"
    return "DTensor(non-rank0)"
def _serialize_io(value):
    """
    Recursively build a JSON-serializable Python structure from `value`.
    Tensors and DTensors become sanitized repr strings.
    Lists/tuples/dicts are recursed into.
    All memory addresses are replaced with a stable placeholder.

    Args:
        value: Any Python object, often including torch Tensors, lists, dicts, etc.

    Returns:
        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
    """
    if isinstance(value, (list, tuple)):
        return [_serialize_io(v) for v in value]

    if isinstance(value, dict):
        return {k: _serialize_io(v) for k, v in value.items()}

    if hasattr(value, "_local_tensor"):
        # DTensor-like handling: use the local tensor attribute
        return {
            "shape": repr(value._local_tensor.shape),
            "dtype": repr(value._local_tensor.dtype),
            "value": _sanitize_repr_for_diff(repr(value)),
        }

    if isinstance(value, torch.Tensor):
        # standard PyTorch Tensor: also record its shape and dtype
        return {"shape": repr(value.shape), "dtype": repr(value.dtype), "value": _sanitize_repr_for_diff(repr(value))}

    # fallback for everything else (bool, int, float, None, or custom class)
    return _sanitize_repr_for_diff(repr(value))
def prune_outputs_if_children(node):
    # if there are children, remove this node's "outputs"
    # so we only see outputs at the leaf level
    if node.get("children"):
        node.pop("outputs", None)
        for child in node["children"]:
            prune_outputs_if_children(child)


def log_model_debug_trace(debug_path, model):
    if debug_path:
        try:
            os.makedirs(debug_path, exist_ok=False)
            output_path = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree.json")
        except Exception as e:
            raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
    else:
        output_path = model._debugger_module_dump_name + "_debug_tree.json"
    logger.info(f"Writing model trace at {output_path}")
    with open(output_path, "w") as outfile:
        prune_outputs_if_children(model._call_tree)
        json.dump(model._call_tree, outfile, indent=2)
def _attach_debugger_logic(model, class_name, debug_path: str):
    # Prepare data structures on the model object
    model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
    model._debugger_model_call_stack = []
    model._debugger_module_dump_name = class_name  # used for final JSON filename

    def wrap_forward(module, full_path):
        orig_forward = module.forward

        @functools.wraps(orig_forward)
        def wrapped_forward(*inps, **kws):
            if _is_rank_zero():
                dict_inputs = {"args": inps, "kwargs": kws}
                dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
                node = {
                    "module_path": full_path,
                    "inputs": _serialize_io(dict_inputs),
                    "outputs": None,
                    "children": [],
                }
                model._debugger_model_call_stack.append(node)
            with torch.inference_mode():
                out = orig_forward(*inps, **kws)
            if _is_rank_zero():
                if sum(1 for _ in module.named_children()) > 0:
                    node["outputs"] = None
                else:
                    node["outputs"] = _serialize_io(out)
                finished = model._debugger_model_call_stack.pop()
                # prune empty vertices here as well (mostly empty children nodes)
                if not finished["children"]:
                    finished.pop("children")
                if model._debugger_model_call_stack:
                    model._debugger_model_call_stack[-1]["children"].append(finished)
            return out

        module.forward = wrapped_forward

    # wrap all submodules
    for name, submodule in model.named_modules():
        if name == "":
            continue
        wrap_forward(submodule, f"{class_name}.{name}")

    # wrap top-level forward
    real_top_forward = model.forward

    @functools.wraps(real_top_forward)
    def top_wrapped_forward(*inps, **kws):
        if _is_rank_zero():
            top_node = {
                "module_path": f"{class_name} (top-level)",
                "inputs": _serialize_io({"args": inps, "kwargs": kws}),
                "outputs": None,
                "children": [],
            }
            model._debugger_model_call_stack.append(top_node)
        out = real_top_forward(*inps, **kws)
        if _is_rank_zero() and model._debugger_model_call_stack:
            top_node["outputs"] = _serialize_io(out)
            finished = model._debugger_model_call_stack.pop()
            model._call_tree["inputs"] = finished["inputs"]
            model._call_tree["outputs"] = finished["outputs"]
            model._call_tree["children"] = finished["children"]
            # prune empty entries for visibility
            [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]]
        return out

    model.forward = top_wrapped_forward

    # Final hook for writing JSON on forward-end
    def final_hook(_, inputs, outputs):
        if _is_rank_zero() and model._debugger_model_call_stack:
            finished = model._debugger_model_call_stack.pop()
            model._call_tree["inputs"] = finished["inputs"]
            model._call_tree["outputs"] = finished["outputs"]
            model._call_tree["children"] = finished["children"]
        if _is_rank_zero():
            log_model_debug_trace(debug_path=debug_path, model=model)

    model.register_forward_hook(final_hook)

    # Optionally also register the hook on a couple of submodules with specific names (there should be just one),
    # i.e. modules that are not typically called through `forward` within the model. We should not need to recurse
    # through them.
    possible_model_calls = ["language_model", "model"]
    for model_call in possible_model_calls:
        this_model_call = getattr(model, model_call, None)
        if this_model_call and isinstance(this_model_call, (nn.Module, PreTrainedModel)):
            this_model_call.register_forward_hook(final_hook)
            break  # exit the loop after finding one (there should be just one such call)
@export(backends=("torch",))
def model_addition_debugger(cls):
    """
    # Model addition debugger - a model adder tracer
    This decorator is a power user tool intended for model adders.
    It tracks all forward calls within a model forward and logs a slice of each input and output in a nested JSON file.
    Note that this decorator enforces `torch.inference_mode()`.

    ## Usage

    Add the decorator to your model class:

    ```python
    from ...model_debugging_utils import model_addition_debugger


    @model_addition_debugger
    class MyModel(nn.Module):  # Can inherit from PreTrainedModel too
        ...  # nothing else changes
    ```

    Then, in a separate script (example is for Llava):

    ```python
    import torch
    from PIL import Image
    from transformers import LlavaProcessor, LlavaForConditionalGeneration

    torch.random.manual_seed(673)

    # load pretrained model and processor
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)

    # create random image input
    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

    # prompt
    prompt = "<image>Describe this image."

    # process inputs
    inputs = processor(text=prompt, images=random_image, return_tensors="pt")

    # call forward method (not .generate!)
    with torch.no_grad():
        output = model.forward(**inputs)
    ```
    """
    orig_init = cls.__init__

    @functools.wraps(cls.__init__)
    def wrapped_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # the decorator takes no debug_path, so the trace is written to the current working directory
        _attach_debugger_logic(self, cls.__name__, debug_path=None)

    cls.__init__ = wrapped_init
    return cls
@export(backends=("torch",))
@contextmanager
def model_addition_debugger_context(model, debug_path: str = None):
    """
    # Model addition debugger - context manager for model adders
    This context manager is a power user tool intended for model adders.
    It tracks all forward calls within a model forward and logs a slice of each input and output in a nested JSON file.
    Note that this context manager enforces `torch.inference_mode()`.

    ## Usage

    Add the context manager to a model to debug:

    ```python
    import torch
    from PIL import Image
    from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context

    torch.random.manual_seed(673)

    # load pretrained model and processor
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)

    # create random image input
    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

    # prompt
    prompt = "<image>Describe this image."

    # process inputs
    inputs = processor(text=prompt, images=random_image, return_tensors="pt")

    # call forward method (not .generate!)
    with model_addition_debugger_context(model):
        output = model.forward(**inputs)
    ```
    """
    _attach_debugger_logic(model, model.__class__.__name__, debug_path)
    try:
        yield model
    finally:
        pass

View File

@@ -538,6 +538,14 @@ def convert_and_export_with_cache(*args, **kwargs):
    requires_backends(convert_and_export_with_cache, ["torch"])


def model_addition_debugger(*args, **kwargs):
    requires_backends(model_addition_debugger, ["torch"])


def model_addition_debugger_context(*args, **kwargs):
    requires_backends(model_addition_debugger_context, ["torch"])


ROPE_INIT_FUNCTIONS = None

View File

@@ -222,10 +222,23 @@ def check_dummies(overwrite: bool = False):
                with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                    f.write(dummy_files[backend])
            else:
                # Temporary fix to help people identify which newly introduced objects are not correctly protected.
                for _actual, _dummy in zip(
                    actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
                ):
                    if _actual != _dummy:
                        actual_broken = _actual
                        dummy_broken = _dummy
                        break
                raise ValueError(
                    "The main __init__ has objects that are not present in "
                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
                    f" It is likely the following objects are responsible, see these excerpts: \n"
                    f"---------------------------------- Actual -------------------------------------\n"
                    f" \n {actual_broken} \n"
                    f"---------------------------------- Dummy -------------------------------------\n"
                    f" \n {dummy_broken} \n"
                    "Run `make fix-copies` to fix this."
                )
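For context, the improved error message relies on a simple chunk comparison: the expected dummy file content and the file on disk are both split on the literal string `"class"`, and the first pair of chunks that differ is surfaced in the exception as a hint. A minimal, self-contained sketch of that idea (the two strings below are made-up stand-ins, not real dummy modules):

```python
# Made-up stand-ins for the generated (expected) and on-disk (actual) dummy file contents.
expected = "class A:\n    pass\n\n\nclass B:\n    pass\n"
actual = "class A:\n    pass\n\n\nclass B_with_typo:\n    pass\n"

actual_broken, dummy_broken = None, None
for _actual, _dummy in zip(actual.split("class"), expected.split("class")):
    if _actual != _dummy:
        # the first mismatching chunk hints at the object that is missing or different
        actual_broken, dummy_broken = _actual, _dummy
        break

print("Actual:", actual_broken)  # " B_with_typo:\n    pass\n"
print("Dummy:", dummy_broken)    # " B:\n    pass\n"
```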