Fix CPU offload + disk offload tests (#27204)

Fix disk offload tests + weight sharing issues
This commit is contained in:
Lysandre Debut 2023-11-01 19:25:23 +01:00 committed by GitHub
parent c9e72f55b2
commit 95020f208e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 155 additions and 13 deletions

View File

@ -4576,7 +4576,9 @@ def expand_device_map(device_map, param_names):
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map

View File

@ -1125,6 +1125,11 @@ class BartModel(BartPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_input_embeddings(self):
return self.shared

View File

@ -2312,6 +2312,11 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -1783,6 +1783,11 @@ class LongT5Model(LongT5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
@ -1937,6 +1942,11 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@ -2170,6 +2180,10 @@ class LongT5EncoderModel(LongT5PreTrainedModel):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -1103,6 +1103,11 @@ class M2M100Model(M2M100PreTrainedModel):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -1471,6 +1471,11 @@ class NllbMoeModel(NllbMoePreTrainedModel):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -1084,6 +1084,11 @@ class PLBartModel(PLBartPreTrainedModel):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -4125,6 +4125,11 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
self.text_decoder.embed_tokens = value
self.shared = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
@add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
def forward(
self,

View File

@ -1329,6 +1329,11 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
@ -1505,6 +1510,11 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@ -1807,6 +1817,10 @@ class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -1414,6 +1414,11 @@ class T5Model(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
@ -1620,6 +1625,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@ -1920,6 +1930,10 @@ class T5EncoderModel(T5PreTrainedModel):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
@ -2152,6 +2166,11 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder

View File

@ -973,6 +973,12 @@ class UMT5Model(UMT5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
# Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
def get_encoder(self):
return self.encoder
@ -1142,6 +1148,12 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@ -1380,6 +1392,11 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
# Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
# Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
def get_encoder(self):
return self.encoder
@ -1615,6 +1632,12 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
# Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
def get_encoder(self):
return self.encoder

View File

@ -182,7 +182,11 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix me (once this model gets more usage)
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
def test_disk_offload(self):
def test_disk_offload_bin(self):
super().test_disk_offload()
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
def test_disk_offload_safetensors(self):
super().test_disk_offload()
# TODO: Fix me (once this model gets more usage)

View File

@ -1788,7 +1788,11 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
pass
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_disk_offload(self):
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_disk_offload_safetensors(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")

View File

@ -2578,7 +2578,45 @@ class ModelTesterMixin:
@require_accelerate
@mark.accelerate_tests
@require_torch_gpu
def test_disk_offload(self):
def test_disk_offload_bin(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
with self.assertRaises(ValueError):
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
# This errors out cause it's missing an offload folder
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
max_size = int(self.model_split_percents[1] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_accelerate
@mark.accelerate_tests
@require_torch_gpu
def test_disk_offload_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@ -2595,17 +2633,11 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
with self.assertRaises(ValueError):
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
# This errors out cause it's missing an offload folder
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
max_size = int(self.model_split_percents[1] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
# This doesn't error out as it's in safetensors and doesn't need an offload folder
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)