diff --git a/src/transformers/modeling_fx_utils.py b/src/transformers/modeling_fx_utils.py index 1bad3e4ec7a..e9cdf00ce89 100644 --- a/src/transformers/modeling_fx_utils.py +++ b/src/transformers/modeling_fx_utils.py @@ -68,6 +68,8 @@ class HFProxy(Proxy): if self.tracer.num_choices <= 0: raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.") shape = shape[:1] + [self.tracer.num_choices] + shape[1:] + elif "hidden_states.s" in code_context: + shape = shape + [self.tracer.root.config.hidden_size] else: # Default case: # - If self.size is called for an unpacking, retrieves the corresponding unpacking