Add a getattr method, which replaces _module_getattr in torch.fx.Tracer from PyTorch 1.13+ (#19233)

This commit is contained in:
Michael Benayoun 2022-09-29 11:04:49 +02:00 committed by GitHub
parent 9d732fd2dd
commit bb6fa06f2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -862,11 +862,12 @@ class HFTracer(Tracer):
return rv
# Replaced by .getattr from PyTorch 1.13
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
@ -899,6 +900,10 @@ class HFTracer(Tracer):
return attr_val
# Needed for PyTorch 1.13+
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)