fixed shape issue for T5 tracing (#11742)

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun 2021-05-17 12:17:31 +02:00 committed by GitHub
parent 0fc56df5fb
commit a0531c8a24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -68,6 +68,8 @@ class HFProxy(Proxy):
if self.tracer.num_choices <= 0:
raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.")
shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
elif "hidden_states.s" in code_context:
shape = shape + [self.tracer.root.config.hidden_size]
else:
# Default case:
# - If self.size is called for an unpacking, retrieves the corresponding unpacking