From 95020f208ed7c30895685af60ef3a791fb2d45ff Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Wed, 1 Nov 2023 19:25:23 +0100
Subject: [PATCH] Fix CPU offload + disk offload tests (#27204)

Fix disk offload tests + weight sharing issues
---
 src/transformers/modeling_utils.py            |  4 +-
 src/transformers/models/bart/modeling_bart.py |  5 ++
 .../modeling_bigbird_pegasus.py               |  5 ++
 .../models/longt5/modeling_longt5.py          | 14 +++++
 .../models/m2m_100/modeling_m2m_100.py        |  5 ++
 .../models/nllb_moe/modeling_nllb_moe.py      |  5 ++
 .../models/plbart/modeling_plbart.py          |  5 ++
 .../seamless_m4t/modeling_seamless_m4t.py     |  5 ++
 .../modeling_switch_transformers.py           | 14 +++++
 src/transformers/models/t5/modeling_t5.py     | 19 +++++++
 src/transformers/models/umt5/modeling_umt5.py | 23 ++++++++
 tests/models/vitdet/test_modeling_vitdet.py   |  6 ++-
 tests/models/whisper/test_modeling_whisper.py |  6 ++-
 tests/test_modeling_common.py                 | 52 +++++++++++++++----
 14 files changed, 155 insertions(+), 13 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 31ed52b4748..e48c98c791b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4576,7 +4576,9 @@ def expand_device_map(device_map, param_names):
     """
     new_device_map = {}
     for module, device in device_map.items():
-        new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
     return new_device_map


diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c271aabcb4d..60ec557eba7 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1125,6 +1125,11 @@ class BartModel(BartPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_input_embeddings(self):
         return self.shared

diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index cc69e31e5a7..b9a84a869da 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2312,6 +2312,11 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index 91e584d80d3..0ae7cedea00 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -1783,6 +1783,11 @@ class LongT5Model(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1937,6 +1942,11 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -2170,6 +2180,10 @@ class LongT5EncoderModel(LongT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index ed17796d27d..4e5004fc98f 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1103,6 +1103,11 @@ class M2M100Model(M2M100PreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index 1c08a70875e..22708f11252 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -1471,6 +1471,11 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 79b5c09cba2..4d8fe161f80 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1084,6 +1084,11 @@ class PLBartModel(PLBartPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 3fe519d2d25..62d1f3e21f9 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -4125,6 +4125,11 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
         self.text_decoder.embed_tokens = value
         self.shared = value

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
         self,
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index e00a0147e42..07c96a5aa82 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1329,6 +1329,11 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1505,6 +1510,11 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1807,6 +1817,10 @@ class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 3748e5af778..ff8e6609b94 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1414,6 +1414,11 @@ class T5Model(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1620,6 +1625,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1920,6 +1930,10 @@ class T5EncoderModel(T5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -2152,6 +2166,11 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index bfcbfb13eb4..220aff273bc 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -973,6 +973,12 @@ class UMT5Model(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
     def get_encoder(self):
         return self.encoder

@@ -1142,6 +1148,12 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1380,6 +1392,11 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
     def get_encoder(self):
         return self.encoder

@@ -1615,6 +1632,12 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
     def get_encoder(self):
         return self.encoder

diff --git a/tests/models/vitdet/test_modeling_vitdet.py b/tests/models/vitdet/test_modeling_vitdet.py
index d6ffd03cbd7..361e563d58d 100644
--- a/tests/models/vitdet/test_modeling_vitdet.py
+++ b/tests/models/vitdet/test_modeling_vitdet.py
@@ -182,7 +182,11 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

     # TODO: Fix me (once this model gets more usage)
     @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        super().test_disk_offload()
+
+    @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+    def test_disk_offload_safetensors(self):
         super().test_disk_offload()

     # TODO: Fix me (once this model gets more usage)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index bc1a7bd218c..60a2d3b93ea 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1788,7 +1788,11 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         pass

     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
+    def test_disk_offload_safetensors(self):
         pass

     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 3c481007472..595c72cda6f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2578,7 +2578,45 @@ class ModelTesterMixin:
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_gpu
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_class)
+
+            model_size = compute_module_sizes(model)[""]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
+
+                with self.assertRaises(ValueError):
+                    max_size = int(self.model_split_percents[0] * model_size)
+                    max_memory = {0: max_size, "cpu": max_size}
+                    # This errors out cause it's missing an offload folder
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+                new_model = model_class.from_pretrained(
+                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+                )
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict_class)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_gpu
+    def test_disk_offload_safetensors(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
@@ -2595,17 +2633,11 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.cpu().save_pretrained(tmp_dir)

-                with self.assertRaises(ValueError):
-                    max_size = int(self.model_split_percents[0] * model_size)
-                    max_memory = {0: max_size, "cpu": max_size}
-                    # This errors out cause it's missing an offload folder
-                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
-
                 max_size = int(self.model_split_percents[1] * model_size)
                 max_memory = {0: max_size, "cpu": max_size}
-                new_model = model_class.from_pretrained(
-                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
-                )
+
+                # This doesn't error out as it's in safetensors and doesn't need an offload folder
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                 torch.manual_seed(0)
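
A minimal usage sketch (not part of the patch) of the behaviour the reworked tests exercise: a checkpoint saved in the .bin format needs an explicit offload_folder when weights are spilled to disk, while a safetensors checkpoint can be read lazily from the file itself and needs none. The model name, memory budgets, and device assumptions below are illustrative, not values taken from the test suite.

# Sketch only: "t5-small" and the memory budgets are illustrative assumptions.
# Assumes one CUDA device is available (device index 0 in max_memory).
import tempfile

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # has tied encoder/decoder embeddings
budget = {0: "100MB", "cpu": "200MB"}  # small enough to force part of the model onto disk

with tempfile.TemporaryDirectory() as tmp_dir:
    # .bin serialization: disk-offloaded weights must be re-materialized on disk,
    # so from_pretrained needs an offload_folder.
    model.save_pretrained(tmp_dir, safe_serialization=False)
    reloaded = AutoModelForSeq2SeqLM.from_pretrained(
        tmp_dir, device_map="auto", max_memory=budget, offload_folder=tmp_dir
    )

with tempfile.TemporaryDirectory() as tmp_dir:
    # safetensors serialization: the checkpoint itself serves the offloaded weights,
    # so no offload_folder is required.
    model.save_pretrained(tmp_dir)  # safetensors is the default format in recent versions
    reloaded = AutoModelForSeq2SeqLM.from_pretrained(
        tmp_dir, device_map="auto", max_memory=budget
    )

The added _tie_weights overrides re-tie the shared encoder/decoder embeddings after the checkpoint is dispatched across devices, which appears to be the "weight sharing issues" mentioned in the commit message.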