From ff20f9cf3615a8638023bc82925573cb9d0f3560 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 23 Mar 2023 10:34:43 +0100
Subject: [PATCH] [`MBart`] Add `accelerate` support for MBart (#22309)

add `accelerate` support for MBart
---
 src/transformers/models/mbart/modeling_mbart.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 55c381a530b..0bf3fb62f76 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -502,6 +502,7 @@ class MBartPreTrainedModel(PreTrainedModel):
     config_class = MBartConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -702,10 +703,10 @@ class MBartEncoder(MBartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -793,7 +794,7 @@ class MBartEncoder(MBartPreTrainedModel):
 
         embed_pos = self.embed_positions(input)
 
-        hidden_states = inputs_embeds + embed_pos
+        hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
         hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -876,10 +877,10 @@ class MBartDecoder(MBartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1038,7 +1039,7 @@ class MBartDecoder(MBartPreTrainedModel):
 
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
 
-        hidden_states = inputs_embeds + positions
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
         hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
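
For reference, a minimal usage sketch of what this patch enables (not part of the patch itself): with `_no_split_modules` defined on `MBartPreTrainedModel`, `from_pretrained` can shard an MBart checkpoint across the available devices through accelerate's `device_map="auto"`, and the `.to(inputs_embeds.device)` casts keep the positional embeddings on the same device as the hidden states when layers land on different GPUs. The checkpoint name, languages, and generation settings below are illustrative assumptions, not something this PR prescribes.

# Usage sketch: load MBart sharded across devices (requires `accelerate` to be installed).
# The checkpoint and language codes are assumed for illustration only.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "facebook/mbart-large-50-many-to-many-mmt"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",          # dispatches layers across available GPUs/CPU via accelerate
    torch_dtype=torch.float16,  # optional; lowers memory use per device
)

# Inputs go to the device holding the input embeddings (the first device in the map).
tokenizer.src_lang = "en_XX"
inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt").to(model.device)
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],  # target language for the mBART-50 MMT checkpoint
    max_new_tokens=40,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))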