[XLA] Improve t5 model performance (#18288)

This commit is contained in:
Yanming Wang 2022-07-27 01:44:14 -07:00 committed by GitHub
parent e318cda9ee
commit d5610b53fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 6 deletions

View File

@@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
@@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(

View File

@@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
@@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(