compute seq_len from inputs_embeds (#13128)

This commit is contained in:
sararb 2021-08-16 12:36:08 -04:00 committed by GitHub
parent e2f07c01e9
commit 14e9d2954c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -854,12 +854,12 @@ class ElectraModel(ElectraPreTrainedModel):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None: