mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add a getattr method, which replaces _module_getattr in torch.fx.Tracer from PyTorch 1.13+ (#19233)
This commit is contained in:
parent
9d732fd2dd
commit
bb6fa06f2d
@ -862,11 +862,12 @@ class HFTracer(Tracer):
|
||||
|
||||
return rv
|
||||
|
||||
# Replaced by .getattr from PyTorch 1.13
|
||||
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||
if getattr(self, "_disable_module_getattr", False):
|
||||
return attr_val
|
||||
else:
|
||||
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
||||
for n, p in collection_to_search:
|
||||
if attr_val is p:
|
||||
@ -899,6 +900,10 @@ class HFTracer(Tracer):
|
||||
|
||||
return attr_val
|
||||
|
||||
# Needed for PyTorch 1.13+
|
||||
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
|
||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
def call_module(self, m, forward, args, kwargs):
|
||||
self.orig_forward = forward
|
||||
return super().call_module(m, forward, args, kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user