diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index d3255baf847..9fd66df175f 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -862,11 +862,12 @@ class HFTracer(Tracer): return rv + # Replaced by .getattr from PyTorch 1.13 def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if getattr(self, "_disable_module_getattr", False): return attr_val else: - # return super()._module_getattr(attr, attr_val, parameter_proxy_cache) + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): for n, p in collection_to_search: if attr_val is p: @@ -899,6 +900,10 @@ class HFTracer(Tracer): return attr_val + # Needed for PyTorch 1.13+ + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + return self._module_getattr(attr, attr_val, parameter_proxy_cache) + def call_module(self, m, forward, args, kwargs): self.orig_forward = forward return super().call_module(m, forward, args, kwargs)