Fix FlauBERT GPU test (#6142)

* Fix GPU test

* Remove legacy constructor
This commit is contained in:
Lysandre Debut 2020-07-30 11:11:48 -04:00 committed by GitHub
parent 91cb95461e
commit ec0267475c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -163,11 +163,13 @@ class FlaubertModel(XLMModel):
else:
bs, slen = inputs_embeds.size()[:-1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if lengths is None:
if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long()
else:
lengths = torch.LongTensor([slen] * bs)
lengths = torch.tensor([slen] * bs, device=device)
# mask = input_ids != self.pad_index
# check inputs
@ -184,8 +186,6 @@ class FlaubertModel(XLMModel):
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids
if position_ids is None:
position_ids = torch.arange(slen, dtype=torch.long, device=device)