From d56d723fad9ee39dea10f6d98008e7ccac243e08 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 27 Oct 2022 18:06:55 +0200
Subject: [PATCH] Add `accelerate` support for M2M100 (#19912)

* add `accelerate` support for M2M100

* fix device set nit
---
 .../models/m2m_100/modeling_m2m_100.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 1d5f9d5eadc..497270bd2ac 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -532,6 +532,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
     config_class = M2M100Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["M2M100Attention"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -693,10 +694,10 @@ class M2M100Encoder(M2M100PreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = M2M100SinusoidalPositionalEmbedding(
             config.max_position_embeddings,
@@ -777,6 +778,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_ids, inputs_embeds)
+        embed_pos = embed_pos.to(inputs_embeds.device)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -868,10 +870,10 @@ class M2M100Decoder(M2M100PreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = M2M100SinusoidalPositionalEmbedding(
             config.max_position_embeddings,
@@ -1010,6 +1012,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         # embed positions
         positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
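
A minimal usage sketch of what this patch enables, assuming `accelerate` is installed and at least one CUDA device is available; the `facebook/m2m100_418M` checkpoint, the en->fr language pair, and the fp16 dtype are illustrative choices, not taken from the patch:

# a minimal sketch, assuming `accelerate` is installed and a CUDA device is present
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# device_map="auto" lets accelerate place the model across available devices;
# each M2M100Attention block (listed in _no_split_modules) stays on a single device
model = M2M100ForConditionalGeneration.from_pretrained(
    "facebook/m2m100_418M", device_map="auto", torch_dtype=torch.float16
)
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"
encoded = tokenizer("Life is like a box of chocolates.", return_tensors="pt").to(0)

# force French as the target language for generation
generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("fr"))
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

The `.to(inputs_embeds.device)` casts added in the encoder and decoder forward passes are what keep this working when the sinusoidal position embeddings end up on a different device than the input embeddings under such a multi-device placement.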