Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
[XLA] Improve t5 model performance (#18288)
This commit is contained in:
commit d5610b53fa
parent e318cda9ee
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
 
-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
 
 
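Why the assert had to go: `torch.Tensor.item()` copies a scalar back to the host and blocks until the device has actually computed it, and under torch/XLA it also forces the lazily traced graph to compile and execute early on every call to `_shift_right`. Below is a minimal sketch of the resulting pattern as a standalone function (the explicit `pad_token_id`/`decoder_start_token_id` parameters are illustrative; the models read these from `self.config`):

import torch

def shift_right(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and shift the sequence right by one.
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_token_id)
    # The removed `assert torch.all(shifted >= 0).item()` would block here
    # on a device-to-host copy; without it every op above stays lazy/async.
    return shifted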
@@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
 
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
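The second hunk swaps allocate-then-move for direct on-device allocation: `torch.ones(...).to(device)` first builds the tensor on the default (CPU) device and then issues a copy, while `torch.ones(..., device=device)` creates it where it is needed; under XLA this keeps the constant inside the traced graph instead of inserting a host-to-device transfer. A self-contained comparison (the CUDA fallback here is illustrative, not from the commit):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two-step: allocate on the CPU, then copy to the target device.
mask_via_copy = torch.ones(4, 128).to(device)

# One-step: allocate directly on the target device -- no intermediate
# host tensor, no transfer.
mask_direct = torch.ones(4, 128, device=device)

assert torch.equal(mask_via_copy, mask_direct)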
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
 
-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
 
 
@@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
 
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
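The two modeling_t5.py hunks match the LongT5 ones line for line. As a rough sanity check of the allocation change on ordinary eager PyTorch (this micro-benchmark is a sketch, not part of the commit; an XLA measurement would instead run under torch_xla with its step markers):

import time
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def timed(fn, iters: int = 1000) -> float:
    # Wall-clock timing; synchronize on CUDA so queued copies are counted.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

t_copy = timed(lambda: torch.ones(8, 512).to(device))
t_direct = timed(lambda: torch.ones(8, 512, device=device))
print(f".to(device): {t_copy:.4f}s   device= kwarg: {t_direct:.4f}s")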