mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
holy shit it was just graph breaks
This commit is contained in:
parent
253307a305
commit
a267d8d472
@ -1597,7 +1597,7 @@ class CompileConfig:
|
||||
```
|
||||
"""
|
||||
|
||||
fullgraph: bool = True
|
||||
fullgraph: bool = False
|
||||
dynamic: Optional[bool] = None
|
||||
backend: Union[str, Callable] = "inductor"
|
||||
mode: str = "reduce-overhead"
|
||||
|
@ -963,13 +963,9 @@ def can_return_tuple(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def register_hook_if_needed(layer, capture_outputs):
|
||||
if is_compiling():
|
||||
pass
|
||||
# TorchDynamo is tracing — wrap in disable context
|
||||
else:
|
||||
# Eager mode — no need to disable
|
||||
return layer.register_forward_hook(capture_outputs)
|
||||
return layer.register_forward_hook(capture_outputs)
|
||||
|
||||
|
||||
def check_model_inputs(func):
|
||||
@ -1010,8 +1006,9 @@ def check_model_inputs(func):
|
||||
return capture_fn
|
||||
|
||||
capture_flags = self._can_record_outputs
|
||||
if "kwargs" in all_args and not is_compiling():
|
||||
all_args.update(**all_args["kwargs"])
|
||||
if "kwargs" in all_args:
|
||||
for k, v in all_args["kwargs"].items(): # we do this for dynamo compile
|
||||
all_args[k] = v
|
||||
recordable_keys = {
|
||||
f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
|
||||
for k in capture_flags
|
||||
|
Loading…
Reference in New Issue
Block a user