mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix CPU offload + disk offload tests (#27204)
Fix disk offload tests + weight sharing issues
This commit is contained in:
parent c9e72f55b2
commit 95020f208e
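In short, the change has three parts, all visible in the hunks below: expand_device_map now also matches the catch-all "" key that a device_map uses to place the whole model on one device; the shared-embedding seq2seq models (Bart, BigBirdPegasus, LongT5, M2M100, NllbMoe, PLBart, SeamlessM4T, SwitchTransformers, T5, UMT5) gain a _tie_weights override so tied embeddings are re-tied after an offloaded load; and the common test_disk_offload is split into test_disk_offload_bin and test_disk_offload_safetensors, since only .bin checkpoints require an offload folder.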
@@ -4576,7 +4576,9 @@ def expand_device_map(device_map, param_names):
     """
     new_device_map = {}
     for module, device in device_map.items():
-        new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
     return new_device_map
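Why the new `module == ""` clause matters: a device_map may use the empty-string key to assign the whole model to a single device, and the old predicate could never match it (no parameter name equals "" or starts with "."), so the expanded map came back empty. A minimal sketch of the fixed behavior, with illustrative inputs that are not taken from the source:

    device_map = {"": "disk"}  # "" is the catch-all key for the whole model
    param_names = ["encoder.weight", "decoder.weight"]

    expanded = {
        p: device
        for module, device in device_map.items()
        for p in param_names
        if p == module or p.startswith(f"{module}.") or module == ""
    }
    assert expanded == {"encoder.weight": "disk", "decoder.weight": "disk"}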
@@ -1125,6 +1125,11 @@ class BartModel(BartPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_input_embeddings(self):
         return self.shared
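Background for this hunk and the analogous ones below: when a checkpoint is loaded with CPU or disk offload, tensors that used to be shared can come back as independent copies, so each shared-embedding model now re-ties shared, encoder.embed_tokens, and decoder.embed_tokens after loading. Roughly speaking, PreTrainedModel._tie_or_clone_weights points the first module's weight at the second's; a simplified sketch, not the library's full implementation (which also handles TorchScript cloning and bias padding):

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        # Simplified: make both modules share one Parameter object.
        output_embeddings.weight = input_embeddings.weight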
@@ -2312,6 +2312,11 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1783,6 +1783,11 @@ class LongT5Model(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1937,6 +1942,11 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
@@ -2170,6 +2180,10 @@ class LongT5EncoderModel(LongT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
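Note that the encoder-only variants in this commit (LongT5EncoderModel above, and SwitchTransformersEncoderModel, T5EncoderModel, UMT5EncoderModel below) have no decoder, so their _tie_weights ties only encoder.embed_tokens to shared, hence the slightly shorter hunks.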
@@ -1103,6 +1103,11 @@ class M2M100Model(M2M100PreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1471,6 +1471,11 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1084,6 +1084,11 @@ class PLBartModel(PLBartPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -4125,6 +4125,11 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
         self.text_decoder.embed_tokens = value
         self.shared = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1329,6 +1329,11 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1505,6 +1510,11 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
@@ -1807,6 +1817,10 @@ class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1414,6 +1414,11 @@ class T5Model(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -1620,6 +1625,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
@@ -1920,6 +1930,10 @@ class T5EncoderModel(T5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -2152,6 +2166,11 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
@@ -973,6 +973,12 @@ class UMT5Model(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
     def get_encoder(self):
         return self.encoder
@@ -1142,6 +1148,12 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
@@ -1380,6 +1392,11 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
     def get_encoder(self):
         return self.encoder
@@ -1615,6 +1632,12 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
     def get_encoder(self):
         return self.encoder
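A quick way to confirm what these overrides buy (a hypothetical check, not part of the diff; device_map="auto" requires accelerate): after an offloaded load, all three embedding modules should share a single parameter.

    from transformers import T5Model

    model = T5Model.from_pretrained("t5-small", device_map="auto", offload_folder="offload")
    assert model.shared.weight is model.encoder.embed_tokens.weight
    assert model.shared.weight is model.decoder.embed_tokens.weight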
@@ -182,7 +182,11 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     # TODO: Fix me (once this model gets more usage)
     @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         super().test_disk_offload()
 
+    @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+    def test_disk_offload_safetensors(self):
+        super().test_disk_offload()
+
     # TODO: Fix me (once this model gets more usage)
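The renames in this test file and the Whisper one below track the split in ModelTesterMixin (last two hunks): a skip decorating the old test_disk_offload name would no longer override anything after the rename, and the new base tests would run unskipped on these models.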
@@ -1788,7 +1788,11 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         pass
 
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         pass
 
+    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
+    def test_disk_offload_safetensors(self):
+        pass
+
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
@@ -2578,7 +2578,45 @@ class ModelTesterMixin:
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_gpu
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_class)
+
+            model_size = compute_module_sizes(model)[""]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
+
+                with self.assertRaises(ValueError):
+                    max_size = int(self.model_split_percents[0] * model_size)
+                    max_memory = {0: max_size, "cpu": max_size}
+                    # This errors out cause it's missing an offload folder
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+                new_model = model_class.from_pretrained(
+                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+                )
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict_class)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_gpu
+    def test_disk_offload_safetensors(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
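What the bin test pins down: safe_serialization=False makes save_pretrained write a pickle-based pytorch_model.bin instead of the now-default model.safetensors, and loading a .bin checkpoint with a device map that spills weights to disk raises a ValueError unless an offload folder is supplied. A minimal sketch of that contract (names and paths illustrative):

    model.save_pretrained(tmp_dir, safe_serialization=False)  # writes pytorch_model.bin

    # Raises ValueError: disk offload of a .bin checkpoint needs somewhere
    # to materialize the offloaded weights.
    model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

    # Succeeds once an offload folder is given.
    model_class.from_pretrained(
        tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
    )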
@@ -2595,17 +2633,11 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.cpu().save_pretrained(tmp_dir)
 
-                with self.assertRaises(ValueError):
-                    max_size = int(self.model_split_percents[0] * model_size)
-                    max_memory = {0: max_size, "cpu": max_size}
-                    # This errors out cause it's missing an offload folder
-                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
-
                 max_size = int(self.model_split_percents[1] * model_size)
                 max_memory = {0: max_size, "cpu": max_size}
-                new_model = model_class.from_pretrained(
-                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
-                )
+                # This doesn't error out as it's in safetensors and doesn't need an offload folder
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
 
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                 torch.manual_seed(0)
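By contrast, the safetensors variant above drops the assertRaises(ValueError) block entirely: safetensors checkpoints can be memory-mapped and read lazily, so weights placed on disk are served straight from the checkpoint file and no separate offload_folder is needed, exactly as the new inline comment says.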