holy shit it was just graph breaks

This commit is contained in:
Arthur 2025-07-02 12:17:30 +02:00
parent 253307a305
commit a267d8d472
2 changed files with 6 additions and 9 deletions

View File

@ -1597,7 +1597,7 @@ class CompileConfig:
```
"""
fullgraph: bool = True
fullgraph: bool = False
dynamic: Optional[bool] = None
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"

View File

@ -963,13 +963,9 @@ def can_return_tuple(func):
return wrapper
@torch._dynamo.disable
def register_hook_if_needed(layer, capture_outputs):
if is_compiling():
pass
# TorchDynamo is tracing — wrap in disable context
else:
# Eager mode — no need to disable
return layer.register_forward_hook(capture_outputs)
return layer.register_forward_hook(capture_outputs)
def check_model_inputs(func):
@ -1010,8 +1006,9 @@ def check_model_inputs(func):
return capture_fn
capture_flags = self._can_record_outputs
if "kwargs" in all_args and not is_compiling():
all_args.update(**all_args["kwargs"])
if "kwargs" in all_args:
for k, v in all_args["kwargs"].items(): # we do this for dynamo compile
all_args[k] = v
recordable_keys = {
f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
for k in capture_flags