Mirror of https://github.com/huggingface/transformers.git
Fix FlauBERT GPU test (#6142)

* Fix GPU test
* Remove legacy constructor
parent 91cb95461e
commit ec0267475c
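In short, the commit computes the target device before the fallback `lengths` tensor is built, and replaces the legacy `torch.LongTensor(...)` constructor (which always allocates on the CPU) with `torch.tensor(..., device=device)`. Below is a minimal, self-contained sketch of that pattern; the helper name `fallback_lengths` and the default `pad_index` value are illustrative stand-ins, not part of the library.

```python
import torch

def fallback_lengths(input_ids=None, inputs_embeds=None, pad_index=2):
    # Work out batch size and sequence length from whichever input is present.
    if input_ids is not None:
        bs, slen = input_ids.size()
    else:
        bs, slen = inputs_embeds.size()[:-1]

    # Pick the device from the inputs, mirroring what the patched forward() does.
    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if input_ids is not None:
        # Real tokens: count the non-padding positions per sequence.
        return (input_ids != pad_index).sum(dim=1).long()
    # Embeddings only: assume full-length sequences. The legacy
    # torch.LongTensor([slen] * bs) would always land on the CPU;
    # torch.tensor(..., device=device) follows the inputs instead.
    return torch.tensor([slen] * bs, device=device)

# Both calls return a tensor on the same device as the inputs.
print(fallback_lengths(input_ids=torch.tensor([[5, 6, 2, 2]])))  # tensor([2])
print(fallback_lengths(inputs_embeds=torch.randn(1, 4, 8)))      # tensor([4])
```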
@@ -163,11 +163,13 @@ class FlaubertModel(XLMModel):
         else:
             bs, slen = inputs_embeds.size()[:-1]
 
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if lengths is None:
             if input_ids is not None:
                 lengths = (input_ids != self.pad_index).sum(dim=1).long()
             else:
-                lengths = torch.LongTensor([slen] * bs)
+                lengths = torch.tensor([slen] * bs, device=device)
         # mask = input_ids != self.pad_index
 
         # check inputs
@@ -184,8 +186,6 @@ class FlaubertModel(XLMModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
 
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
         # position_ids
         if position_ids is None:
             position_ids = torch.arange(slen, dtype=torch.long, device=device)
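For context, here is a hedged sketch of the kind of call the fixed GPU test presumably exercises: the model receives `inputs_embeds` only (no `input_ids`, no `lengths`), so the fallback `lengths` tensor is built inside `forward()`. Before this commit that tensor always lived on the CPU and could clash with CUDA activations; after it, everything stays on the inputs' device. The checkpoint name is just an example, and the snippet assumes a CUDA device is available (it falls back to CPU otherwise).

```python
import torch
from transformers import FlaubertModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlaubertModel.from_pretrained("flaubert/flaubert_small_cased").to(device).eval()

# Build embeddings on the target device and call the model without input_ids,
# which is the code path that used to create `lengths` on the CPU.
input_ids = torch.tensor([[0, 345, 232, 1]], device=device)
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)

print(outputs[0].shape)  # (batch_size, sequence_length, hidden_size)
```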