Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Add accelerate support for LongT5 models (#20341)
* ✨ add accelerate support for LongT5 models
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
* fix `accelerate` tests
* Trigger CI test
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in: parent 8286af6f54, commit a3345c1f13
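The diff below enables loading LongT5 checkpoints through accelerate's big-model utilities. As a rough usage sketch (not part of this diff), the change is what makes `device_map="auto"` work for LongT5; the checkpoint name `google/long-t5-local-base` is a public LongT5 checkpoint used here only for illustration:

# Sketch: loading LongT5 with accelerate's automatic device placement.
# Requires `accelerate` to be installed; device_map="auto" dispatches the
# model's modules across the available GPUs and CPU.
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
model = LongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-local-base", device_map="auto"
)

inputs = tokenizer("summarize: " + "a very long document " * 200, return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))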
@@ -648,9 +648,12 @@ class LongT5LocalAttention(nn.Module):

     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
+        )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]

         # (block_length, 3 * block_length)
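The guard above is needed because, when a model is instantiated under accelerate's `init_empty_weights()` context, its parameters live on the "meta" device, which cannot hold real tensor data; passing that device to `torch.arange` would fail. A minimal sketch of the same pattern outside LongT5 (the `make_position_ids` helper is illustrative, not code from this diff):

# Sketch of the meta-device guard: fall back to device=None so torch.arange
# allocates on the default device when the reference weight is on "meta".
import torch
import torch.nn as nn

embedding = nn.Embedding(32, 8)  # stands in for relative_attention_bias;
                                 # it could also have been created on "meta"
                                 # via `with accelerate.init_empty_weights(): ...`

def make_position_ids(length: int, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the compute_bias change.
    target_device = weight.device if weight.device.type != "meta" else None
    return torch.arange(length, dtype=torch.long, device=target_device)

print(make_position_ids(6, embedding.weight).device)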
@@ -843,9 +846,12 @@ class LongT5TransientGlobalAttention(nn.Module):

     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
+        )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]

         # (block_length, 3 * block_length)
@@ -1271,6 +1277,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
     config_class = LongT5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LongT5Block"]

     @property
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
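`_no_split_modules = ["LongT5Block"]` tells the big-model loading path which module classes must stay whole on a single device when a device map is inferred. A hedged sketch of the same constraint expressed with accelerate directly (standard accelerate API, not code from this diff; the printed map is only an example):

# Sketch: inferring a device map while keeping each LongT5Block on one device.
# This mirrors what from_pretrained(..., device_map="auto") does internally
# when it reads _no_split_modules from the model class.
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoConfig, LongT5ForConditionalGeneration

config = AutoConfig.from_pretrained("google/long-t5-local-base")
with init_empty_weights():
    model = LongT5ForConditionalGeneration(config)

device_map = infer_auto_device_map(model, no_split_module_classes=["LongT5Block"])
print(device_map)  # e.g. {"shared": 0, "encoder.block.0": 0, ...}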
@@ -1366,7 +1373,9 @@ class LongT5Stack(LongT5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
         super().__init__(config)

-        self.embed_tokens = embed_tokens
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
         self.is_decoder = config.is_decoder

         self.local_radius = config.local_radius
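Instead of storing the passed-in embedding module, `LongT5Stack` now creates its own `nn.Embedding` and shares the weight tensor with it. Sharing the `Parameter` keeps the stack's embeddings tied to the model's shared embedding while giving the stack a module of its own for device placement. A minimal sketch of that tying pattern (the variable names are illustrative, not from LongT5):

# Sketch: tying embedding weights by sharing the Parameter object,
# as the LongT5Stack.__init__ change does with its shared embedding.
import torch.nn as nn

shared = nn.Embedding(100, 16)

stack_embed = nn.Embedding(100, 16)
stack_embed.weight = shared.weight  # same Parameter, no copy

assert stack_embed.weight is shared.weight
# Updating one updates the other, so the embeddings stay tied during training.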