Mirror of https://github.com/huggingface/transformers.git
[MBart] Add accelerate support for MBart (#22309)

add `accelerate` support for MBart
parent 61f79b2986
commit ff20f9cf36
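For context, `accelerate` support lets the checkpoint be dispatched across devices at load time. A minimal usage sketch, not part of the commit itself; the checkpoint name and the `device_map="auto"` setting are illustrative and require `accelerate` to be installed:

from transformers import AutoModelForSeq2SeqLM

# Load MBart and let accelerate spread the weights over the available devices.
# The checkpoint name below is only an example.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/mbart-large-50",
    device_map="auto",  # requires `pip install accelerate`
)

# transformers records the resulting placement on the model.
print(model.hf_device_map)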
@@ -502,6 +502,7 @@ class MBartPreTrainedModel(PreTrainedModel):
     config_class = MBartConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
 
     def _init_weights(self, module):
         std = self.config.init_std
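`_no_split_modules` names the module classes that accelerate's device-map inference must keep whole on a single device, so a decoder layer and its residual connections are never sharded across two devices. A hedged sketch of how such a list is consumed by accelerate, assuming `model` is an already-instantiated MBart model:

from accelerate import infer_auto_device_map

# Classes listed in _no_split_modules are forwarded as no_split_module_classes,
# so each MBartDecoderLayer / MBartAttention lands entirely on one device.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["MBartDecoderLayer", "MBartAttention"],
)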
@@ -702,10 +703,10 @@ class MBartEncoder(MBartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
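The encoder now always registers its own `nn.Embedding` and, when a shared embedding is passed in, ties the weights instead of swapping in the parent's module. Keeping the submodule registered lets the device map account for it while the parameters remain shared. A standalone sketch of the pattern, with illustrative sizes rather than MBart's real configuration:

import torch.nn as nn

shared = nn.Embedding(1000, 64, padding_idx=1)  # shared token embedding

class TinyEncoder(nn.Module):
    def __init__(self, embed_tokens=None):
        super().__init__()
        # Always create the local embedding module ...
        self.embed_tokens = nn.Embedding(1000, 64, padding_idx=1)
        if embed_tokens is not None:
            # ... then tie its weight to the shared table instead of replacing the module.
            self.embed_tokens.weight = embed_tokens.weight

enc = TinyEncoder(shared)
assert enc.embed_tokens.weight is shared.weight  # parameters are shared, not copied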
@@ -793,7 +794,7 @@ class MBartEncoder(MBartPreTrainedModel):
 
         embed_pos = self.embed_positions(input)
 
-        hidden_states = inputs_embeds + embed_pos
+        hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
         hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
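With a device map in place, the positional embedding table can sit on a different device than `inputs_embeds`, so the positions are moved to `inputs_embeds.device` before the addition. A toy illustration of the pattern (the CUDA placement is only an example and is skipped when no GPU is present):

import torch

inputs_embeds = torch.randn(2, 8, 16)
embed_pos = torch.randn(2, 8, 16)
if torch.cuda.is_available():
    embed_pos = embed_pos.to("cuda")  # pretend accelerate placed this table elsewhere

# Adding the tensors directly would fail across devices; this form works either way.
hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)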
@@ -876,10 +877,10 @@ class MBartDecoder(MBartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1038,7 +1039,7 @@ class MBartDecoder(MBartPreTrainedModel):
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
 
-        hidden_states = inputs_embeds + positions
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
         hidden_states = self.layernorm_embedding(hidden_states)
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)