mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[XLA] Improve t5 model performance (#18288)
This commit is contained in:
parent
e318cda9ee
commit
d5610b53fa
@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||
|
||||
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
|
||||
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_attention_mask = torch.ones(
|
||||
|
@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||
|
||||
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_attention_mask = torch.ones(
|
||||
|
Loading…
Reference in New Issue
Block a user