From 1d3f35f30aa8d91996ba533df5f93d5b769886ce Mon Sep 17 00:00:00 2001
From: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Date: Thu, 20 Mar 2025 17:37:29 +0100
Subject: [PATCH] Add model visual debugger (#36798)

* draft of model tracer visualiser

* add context manager in addition to decorator

* add debug utils to init

* move model debugging utils to dedicated file

* add documentation

* protect some imports

* format

* move and protect imports

* format

* doc: improve errors in case of broken dummy imports.

* format

* use automatic torch backend

* update doc

* fix backend

* (TEMP) move to dummies while backend wait

* update documentation

* doc
---
 .../en/internal/model_debugging_utils.md      |  71 ++++
 src/transformers/__init__.py                  |   9 +
 src/transformers/model_debugging_utils.py     | 329 ++++++++++++++++++
 src/transformers/utils/dummy_pt_objects.py    |   8 +
 utils/check_dummies.py                        |  17 +-
 5 files changed, 432 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/en/internal/model_debugging_utils.md
 create mode 100644 src/transformers/model_debugging_utils.py

diff --git a/docs/source/en/internal/model_debugging_utils.md b/docs/source/en/internal/model_debugging_utils.md
new file mode 100644
index 00000000000..c5708aa8e65
--- /dev/null
+++ b/docs/source/en/internal/model_debugging_utils.md
@@ -0,0 +1,71 @@
+
+# Model debugging toolboxes
+
+This page lists the debugging and model-addition tools used by the library, as well as the utility functions it provides for them.
+
+Most of these are only useful if you are adding new models to the library.
+
+
+## Model addition debuggers
+
+
+### Model addition debugger - context manager for model adders
+
+This context manager is a power-user tool intended for model adders.
+It tracks all forward calls within a model's forward pass and logs a slice of each input and output to a nested JSON file.
+Note that this context manager enforces `torch.inference_mode()`.
+
+### Rationale
+
+When porting models to transformers, even from Python to Python, model adders often have to perform many manual operations: saving and loading tensors, comparing dtypes, and so on. This small tool can hopefully shave off some of that time.
+
+### Usage
+
+Add this context manager as follows to debug a model:
+
+```python
+import torch
+from PIL import Image
+from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
+torch.random.manual_seed(673)
+
+# load pretrained model and processor
+model_id = "llava-hf/llava-1.5-7b-hf"
+processor = LlavaProcessor.from_pretrained(model_id)
+model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
+
+# create random image input
+random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
+
+# prompt
+prompt = "Describe this image."
+
+# process inputs
+inputs = processor(text=prompt, images=random_image, return_tensors="pt")
+
+# call forward method (not .generate!)
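+# debug_path (optional) is a directory where the JSON trace is written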
+with model_addition_debugger_context(model, debug_path="optional/path/to/your_debug_directory"):
+    output = model.forward(**inputs)
+
+```
+
+
+[[autodoc]] model_addition_debugger
+
+[[autodoc]] model_addition_debugger_context
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 8c7a608b090..56bbcb76f4d 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1376,6 +1376,10 @@ except OptionalDependencyNotAvailable:
     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
 else:
+    _import_structure["model_debugging_utils"] = [
+        "model_addition_debugger",
+        "model_addition_debugger_context",
+    ]
     _import_structure["activations"] = []
     _import_structure["cache_utils"] = [
         "Cache",
@@ -6605,6 +6609,7 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *
     else:
+        # Debugging
        from .cache_utils import (
            Cache,
            CacheConfig,
@@ -6690,6 +6695,10 @@ if TYPE_CHECKING:
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )
+        from .model_debugging_utils import (
+            model_addition_debugger,
+            model_addition_debugger_context,
+        )
        from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
        from .modeling_utils import PreTrainedModel
        from .models.albert import (
diff --git a/src/transformers/model_debugging_utils.py b/src/transformers/model_debugging_utils.py
new file mode 100644
index 00000000000..d45586aee1b
--- /dev/null
+++ b/src/transformers/model_debugging_utils.py
@@ -0,0 +1,329 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import json
+import os
+import re
+from contextlib import contextmanager
+
+from .utils import is_torch_available
+from .utils.import_utils import export
+
+
+if is_torch_available():
+    import torch
+    import torch.distributed.tensor
+    from torch import nn
+
+    from .modeling_utils import PreTrainedModel
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
+_torch_distributed_available = is_torch_available() and torch.distributed.is_available()
+
+
+def _is_rank_zero():
+    """Return True if rank=0 or we aren't running distributed."""
+    if not (_torch_distributed_available and torch.distributed.is_initialized()):
+        return True
+    return torch.distributed.get_rank() == 0
+
+
+MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+")
+
+
+def _sanitize_repr_for_diff(x_str: str) -> str:
+    """
+    Replace memory addresses in an object's repr with a stable placeholder
+    so that JSON diffs aren't polluted by ephemeral addresses.
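+
+    Example (illustrative repr string):
+        "<MyModule object at 0x7fd3e2c4b0d0>" becomes "<MyModule object at 0xXXXXXXXX>"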
+    """
+    return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str)
+
+
+def _dtensor_repr(x):
+    """Return a stable string representation for a DTensor-like object."""
+    if _is_rank_zero():
+        return f"DTensor (rank0) -> {repr(x._local_tensor)}"
+    return "DTensor(non-rank0)"
+
+
+def _serialize_io(value):
+    """
+    Recursively build a JSON-serializable Python structure from `value`.
+    Tensors and DTensors become sanitized repr strings.
+    Lists/tuples/dicts are recursed into.
+    All memory addresses are replaced with a stable placeholder.
+
+    Args:
+        value: Any Python object, often including torch Tensors, lists, dicts, etc.
+
+    Returns:
+        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
+    """
+    if isinstance(value, (list, tuple)):
+        return [_serialize_io(v) for v in value]
+
+    if isinstance(value, dict):
+        return {k: _serialize_io(v) for k, v in value.items()}
+
+    if hasattr(value, "_local_tensor"):
+        # DTensor-like object: serialize the local shard via its `_local_tensor` attribute
+        return {
+            "shape": repr(value._local_tensor.shape),
+            "dtype": repr(value._local_tensor.dtype),
+            "value": _sanitize_repr_for_diff(repr(value)),
+        }
+
+    if isinstance(value, torch.Tensor):
+        # standard PyTorch Tensor: also record its shape and dtype
+        return {"shape": repr(value.shape), "dtype": repr(value.dtype), "value": _sanitize_repr_for_diff(repr(value))}
+
+    # fallback for everything else (bool, int, float, None, or custom class)
+    return _sanitize_repr_for_diff(repr(value))
+
+
+def prune_outputs_if_children(node):
+    """If a node has children, drop its "outputs" key so that outputs only appear at the leaf level."""
+    if node.get("children"):
+        node.pop("outputs", None)
+        for child in node["children"]:
+            prune_outputs_if_children(child)
+
+
+def log_model_debug_trace(debug_path, model):
+    if debug_path:
+        try:
+            os.makedirs(debug_path, exist_ok=False)
+            output_path = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree.json")
+        except Exception as e:
+            raise ValueError(f"Could not create the debug directory debug_path={debug_path}, it may already exist. {e}")
{e}") + else: + output_path = model._debugger_module_dump_name + "_debug_tree.json" + logger.info(f"Writing model trace at {output_path}") + with open(output_path, "w") as outfile: + prune_outputs_if_children(model._call_tree) + json.dump(model._call_tree, outfile, indent=2) + + +def _attach_debugger_logic(model, class_name, debug_path: str): + # Prepare data structures on the model object + model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []} + model._debugger_model_call_stack = [] + model._debugger_module_dump_name = class_name # used for final JSON filename + + def wrap_forward(module, full_path): + orig_forward = module.forward + + @functools.wraps(orig_forward) + def wrapped_forward(*inps, **kws): + if _is_rank_zero(): + dict_inputs = {"args": inps, "kwargs": kws} + dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0} + node = { + "module_path": full_path, + "inputs": _serialize_io(dict_inputs), + "outputs": None, + "children": [], + } + model._debugger_model_call_stack.append(node) + with torch.inference_mode(): + out = orig_forward(*inps, **kws) + + if _is_rank_zero(): + if sum(1 for _ in module.named_children()) > 0: + node["outputs"] = None + else: + node["outputs"] = _serialize_io(out) + + finished = model._debugger_model_call_stack.pop() + # prune empty vertices here as well (mostly empty children nodes) + if not finished["children"]: + finished.pop("children") + + if model._debugger_model_call_stack: + model._debugger_model_call_stack[-1]["children"].append(finished) + return out + + module.forward = wrapped_forward + + # wrap all submodules + for name, submodule in model.named_modules(): + if name == "": + continue + wrap_forward(submodule, f"{class_name}.{name}") + + # wrap top-level forward + real_top_forward = model.forward + + @functools.wraps(real_top_forward) + def top_wrapped_forward(*inps, **kws): + if _is_rank_zero(): + top_node = { + "module_path": f"{class_name} (top-level)", + "inputs": _serialize_io({"args": inps, "kwargs": kws}), + "outputs": None, + "children": [], + } + model._debugger_model_call_stack.append(top_node) + + out = real_top_forward(*inps, **kws) + + if _is_rank_zero() and model._debugger_model_call_stack: + top_node["outputs"] = _serialize_io(out) + finished = model._debugger_model_call_stack.pop() + model._call_tree["inputs"] = finished["inputs"] + model._call_tree["outputs"] = finished["outputs"] + model._call_tree["children"] = finished["children"] + # prune empty stuff for visibility + [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]] + + return out + + model.forward = top_wrapped_forward + + # Final hook for writing JSON on forward-end + def final_hook(_, inputs, outputs): + if _is_rank_zero() and model._debugger_model_call_stack: + finished = model._debugger_model_call_stack.pop() + model._call_tree["inputs"] = finished["inputs"] + model._call_tree["outputs"] = finished["outputs"] + model._call_tree["children"] = finished["children"] + + if _is_rank_zero(): + log_model_debug_trace(debug_path=debug_path, model=model) + + model.register_forward_hook(final_hook) + # Optionally also for a couple possible hooks that have specific names. It should be just one. + # This means modules that are not typically called "forward" within the model. But we should not need to recurse + # through them. 
+    possible_model_calls = ["language_model", "model"]
+    for model_call in possible_model_calls:
+        this_model_call = getattr(model, model_call, None)
+        if this_model_call and isinstance(this_model_call, (nn.Module, PreTrainedModel)):
+            this_model_call.register_forward_hook(final_hook)
+            break  # stop after the first match; there should be just one such module
+
+
+@export(backends=("torch",))
+def model_addition_debugger(cls):
+    """
+    # Model addition debugger - a model adder tracer
+    This decorator is a power-user tool intended for model adders.
+    It tracks all forward calls within a model's forward pass and logs a slice of each input and output to a nested JSON file.
+    Note that this decorator enforces `torch.inference_mode()`.
+    ## Usage
+
+    Add the decorator to your model class:
+    ```python
+    from torch import nn
+
+    from transformers import model_addition_debugger
+
+    @model_addition_debugger
+    class MyModel(nn.Module):  # can inherit from PreTrainedModel too
+        # ... nothing else changes
+    ```
+    Then, in a separate script (this example is for Llava):
+
+    ```python
+    import torch
+    from PIL import Image
+    from transformers import LlavaProcessor, LlavaForConditionalGeneration
+    torch.random.manual_seed(673)
+
+    # load pretrained model and processor
+    model_id = "llava-hf/llava-1.5-7b-hf"
+    processor = LlavaProcessor.from_pretrained(model_id)
+    model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
+
+    # create random image input
+    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
+
+    # prompt
+    prompt = "Describe this image."
+
+    # process inputs
+    inputs = processor(text=prompt, images=random_image, return_tensors="pt")
+
+    # call forward method (not .generate!)
+    with torch.no_grad():
+        output = model.forward(**inputs)
+    ```
+
+    """
+    orig_init = cls.__init__
+
+    @functools.wraps(cls.__init__)
+    def wrapped_init(self, *args, **kwargs):
+        orig_init(self, *args, **kwargs)
+        _attach_debugger_logic(self, cls.__name__)
+
+    cls.__init__ = wrapped_init
+    return cls
+
+
+@export(backends=("torch",))
+@contextmanager
+def model_addition_debugger_context(model, debug_path: str = None):
+    """
+    # Model addition debugger - context manager for model adders
+    This context manager is a power-user tool intended for model adders.
+    It tracks all forward calls within a model's forward pass and logs a slice of each input and output to a nested JSON file.
+    Note that this context manager enforces `torch.inference_mode()`.
+
+    ## Usage
+
+    Add the context manager around a model call to debug it:
+
+    ```python
+    import torch
+    from PIL import Image
+    from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
+    torch.random.manual_seed(673)
+
+    # load pretrained model and processor
+    model_id = "llava-hf/llava-1.5-7b-hf"
+    processor = LlavaProcessor.from_pretrained(model_id)
+    model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
+
+    # create random image input
+    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
+
+    # prompt
+    prompt = "Describe this image."
+
+    # process inputs
+    inputs = processor(text=prompt, images=random_image, return_tensors="pt")
+
+    # call forward method (not .generate!)
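+    # debug_path (optional) is a directory where the JSON trace is written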
+    with model_addition_debugger_context(model):
+        output = model.forward(**inputs)
+    ```
+
+    """
+    _attach_debugger_logic(model, model.__class__.__name__, debug_path)
+    try:
+        yield model
+    finally:
+        pass  # no cleanup: the wrapped forwards and hooks stay attached to the model
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index dc48c13706e..85eea3cb100 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -538,6 +538,14 @@ def convert_and_export_with_cache(*args, **kwargs):
     requires_backends(convert_and_export_with_cache, ["torch"])
 
 
+def model_addition_debugger(*args, **kwargs):
+    requires_backends(model_addition_debugger, ["torch"])
+
+
+def model_addition_debugger_context(*args, **kwargs):
+    requires_backends(model_addition_debugger_context, ["torch"])
+
+
 ROPE_INIT_FUNCTIONS = None
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index e66d69ada1a..73d7ebbfd1d 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -222,10 +222,23 @@ def check_dummies(overwrite: bool = False):
             with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                 f.write(dummy_files[backend])
         else:
+            # Temporary fix to help people identify which newly introduced objects are not correctly protected.
+            actual_broken = dummy_broken = ""
+            for _actual, _dummy in zip(
+                actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
+            ):
+                if _actual != _dummy:
+                    actual_broken = _actual
+                    dummy_broken = _dummy
+                    break
             raise ValueError(
                 "The main __init__ has objects that are not present in "
-                f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
-                "to fix this."
+                f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
+                "It is likely the following objects are responsible; see these excerpts:\n"
+                "---------------------------------- Actual -------------------------------------\n"
+                f"{actual_broken}\n"
+                "---------------------------------- Dummy --------------------------------------\n"
+                f"{dummy_broken}\n"
+                "Run `make fix-copies` to fix this."
             )