mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
compute seq_len from inputs_embeds (#13128)
This commit is contained in:
parent
e2f07c01e9
commit
14e9d2954c
@ -854,12 +854,12 @@ class ElectraModel(ElectraPreTrainedModel):
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
batch_size, seq_length = input_shape
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
|
Loading…
Reference in New Issue
Block a user