Add accelerate support for LongT5 models (#20341)

*  add accelerate support for LongT5 models

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

* fix `accelerate` tests

* Trigger CI test

Signed-off-by: peter szemraj <peterszemraj@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
commit a3345c1f13 (parent 8286af6f54)
Author: Peter
Date:   2022-12-12 15:25:52 +01:00 (committed by GitHub)

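What this enables, as a minimal usage sketch (assumes `accelerate` is installed; the
checkpoint name and generation settings are illustrative, not part of this commit):

    from transformers import AutoTokenizer, LongT5ForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
    # device_map="auto" lets accelerate spread the LongT5 blocks across the
    # available devices instead of loading everything onto a single one.
    model = LongT5ForConditionalGeneration.from_pretrained(
        "google/long-t5-tglobal-base", device_map="auto"
    )

    text = "summarize: " + "A very long document. " * 500
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    summary_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))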

@@ -648,9 +648,12 @@ class LongT5LocalAttention(nn.Module):
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
+        )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
         # (block_length, 3 * block_length)
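This hunk (and the identical one below for LongT5TransientGlobalAttention) guards against
the relative-attention-bias weights still living on the `meta` device, as happens during
accelerate's empty-weight initialization. A standalone sketch of the pattern (module sizes
are illustrative):

    import torch
    import torch.nn as nn

    # Parameters on the "meta" device carry shape/dtype but no data; creating the
    # position tensor there would make it unusable, so fall back to the default device.
    relative_attention_bias = nn.Embedding(32, 8, device="meta")

    device = relative_attention_bias.weight.device
    target_device = device if device.type != "meta" else None  # None -> default device
    block_length = 4
    memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
    print(memory_position.device)  # e.g. cpu, never meta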
@@ -843,9 +846,12 @@ class LongT5TransientGlobalAttention(nn.Module):
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
+        )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
         # (block_length, 3 * block_length)
@@ -1271,6 +1277,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
     config_class = LongT5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LongT5Block"]
 
     @property
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
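`_no_split_modules` marks submodules that accelerate must keep on a single device when it
builds a device map. A sketch of how this surfaces through accelerate's public API
(checkpoint name is illustrative):

    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, LongT5ForConditionalGeneration

    config = AutoConfig.from_pretrained("google/long-t5-local-base")

    # Build the model skeleton on the meta device, without materializing weights.
    with init_empty_weights():
        model = LongT5ForConditionalGeneration(config)

    # from_pretrained(..., device_map="auto") forwards _no_split_modules as
    # no_split_module_classes, so a LongT5Block is never sharded across devices.
    device_map = infer_auto_device_map(
        model, no_split_module_classes=model._no_split_modules
    )
    print(device_map)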
@@ -1366,7 +1373,9 @@ class LongT5Stack(LongT5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
         super().__init__(config)
 
-        self.embed_tokens = embed_tokens
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
         self.is_decoder = config.is_decoder
 
         self.local_radius = config.local_radius
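The LongT5Stack hunk swaps the shared-module assignment for an embedding the stack owns
itself, with its weight tied back to the one passed in, so the parameter shows up under the
stack's own module tree (which matters for device-map dispatch of tied weights) while still
being shared. A minimal sketch of the tying pattern (sizes are illustrative):

    import torch.nn as nn

    shared = nn.Embedding(128, 32)       # e.g. the shared embedding owned by the parent model

    stack_embed = nn.Embedding(128, 32)  # the stack now owns its own module...
    stack_embed.weight = shared.weight   # ...whose weight is tied to the shared Parameter

    assert stack_embed.weight is shared.weight  # one Parameter, referenced by two modules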