Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 19:21:31 +06:00
Nail in edge case of torch dtype being overridden permanently in the case of an error (#35845)
* Nail in edge case of torch dtype
* Rm unused func
* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Refactor tests to only mock what we need, don't introduce injection functions
* SetUp/TearDown
* Do super

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
parent e3458af726
commit 1ce0e2992e
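The problem being fixed: `from_pretrained` and `_from_config` switch PyTorch's process-wide default dtype while building the model, and if loading raised after that switch the previous default was never restored, silently changing the dtype of every tensor created afterwards. A minimal sketch of the failure mode and of the `try`/`finally` guard the new decorator applies (the loader functions below are illustrative stand-ins, not the library's real code path):

import torch


def materialize_weights():
    # stand-in for the work that can fail partway through loading a checkpoint
    raise RuntimeError("simulated failure while building the model")


def load_model_unsafe(target_dtype):
    # pre-fix shape of the problem: the restore line is never reached on error
    torch.set_default_dtype(target_dtype)
    materialize_weights()
    torch.set_default_dtype(torch.float32)


def load_model_guarded(target_dtype):
    # what the new decorator does around `from_pretrained` / `_from_config`
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(target_dtype)
        materialize_weights()
    finally:
        torch.set_default_dtype(old_dtype)  # runs even when loading blows up


try:
    load_model_unsafe(torch.float64)
except RuntimeError:
    pass
print(torch.get_default_dtype())  # torch.float64 -- the default has leaked

torch.set_default_dtype(torch.float32)
try:
    load_model_guarded(torch.float64)
except RuntimeError:
    pass
print(torch.get_default_dtype())  # torch.float32 -- the default is restored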
@@ -246,6 +246,25 @@ def set_zero3_state():
    _is_ds_init_called = False


def restore_default_torch_dtype(func):
    """
    Decorator to restore the default torch dtype
    at the end of the function. Serves
    as a backup in case calling the function raises
    an error after the function has changed the default dtype but before it could restore it.
    """

    @wraps(func)
    def _wrapper(*args, **kwargs):
        old_dtype = torch.get_default_dtype()
        try:
            return func(*args, **kwargs)
        finally:
            torch.set_default_dtype(old_dtype)

    return _wrapper


def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
    try:
        return next(parameter.parameters()).device
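On its own, the new helper behaves like this (a small self-contained sketch that copies the decorator above; `flaky_loader` is an illustrative function, and the final assertion assumes the process still had PyTorch's factory default of float32 before the call):

from functools import wraps

import torch


def restore_default_torch_dtype(func):
    # same helper as in the hunk above, copied so this sketch runs standalone
    @wraps(func)
    def _wrapper(*args, **kwargs):
        old_dtype = torch.get_default_dtype()
        try:
            return func(*args, **kwargs)
        finally:
            torch.set_default_dtype(old_dtype)

    return _wrapper


@restore_default_torch_dtype
def flaky_loader():
    # changes global state, then dies before it can clean up after itself
    torch.set_default_dtype(torch.float64)
    raise RuntimeError("simulated mid-load failure")


try:
    flaky_loader()
except RuntimeError:
    pass
# assumes the process still had PyTorch's factory default (float32) before the call
assert torch.get_default_dtype() == torch.float32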
@@ -1407,6 +1426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        self.model_tags.append(tag)

    @classmethod
    @restore_default_torch_dtype
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
@@ -3142,6 +3162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        return super().float(*args)

    @classmethod
    @restore_default_torch_dtype
    def from_pretrained(
        cls: Type[SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
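With both entry points decorated, a failed load can no longer leave the global default dtype switched. The user-visible contract is roughly the following sketch ("some-org/tiny-model" is a placeholder checkpoint name, not a real repo):

import torch

from transformers import AutoModelForCausalLM

before = torch.get_default_dtype()
try:
    # "some-org/tiny-model" is a placeholder checkpoint name, not a real repo
    model = AutoModelForCausalLM.from_pretrained("some-org/tiny-model", torch_dtype=torch.float16)
except Exception:
    pass  # whatever went wrong while loading...
# ...the process-wide default dtype is unchanged either way
assert torch.get_default_dtype() == before

The `from_config` path gets the same guarantee through the decorated `_from_config`, which is what the second new test below checks.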
@@ -39,6 +39,7 @@ from transformers import (
    AutoModelForSequenceClassification,
    DynamicCache,
    LlavaForConditionalGeneration,
    MistralForCausalLM,
    OwlViTForObjectDetection,
    PretrainedConfig,
    is_torch_available,
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):

@require_torch
class ModelUtilsTest(TestCasePlus):
    def setUp(self):
        self.old_dtype = torch.get_default_dtype()
        super().setUp()

    def tearDown(self):
        torch.set_default_dtype(self.old_dtype)
        super().tearDown()

    @slow
    def test_model_from_pretrained(self):
        model_name = "google-bert/bert-base-uncased"
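The setUp/tearDown pair snapshots the global default dtype before each test and restores it afterwards, so a test that flips the dtype (or fails halfway) cannot poison later tests. The same fixture pattern in isolation (a toy sketch; it assumes the process starts from PyTorch's factory default of float32 and relies on unittest's default alphabetical ordering of test methods):

import unittest

import torch


class DtypeFixtureDemo(unittest.TestCase):
    # same fixture shape as the diff above: snapshot the global default dtype
    # before each test and put it back afterwards
    def setUp(self):
        self.old_dtype = torch.get_default_dtype()
        super().setUp()

    def tearDown(self):
        torch.set_default_dtype(self.old_dtype)
        super().tearDown()

    def test_a_changes_default(self):
        torch.set_default_dtype(torch.float64)
        self.assertEqual(torch.get_default_dtype(), torch.float64)

    def test_b_sees_clean_default(self):
        # runs after test_a (unittest orders methods alphabetically) yet still sees
        # the factory default, because tearDown undid test_a's change
        self.assertEqual(torch.get_default_dtype(), torch.float32)


if __name__ == "__main__":
    unittest.main()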
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
        self.assertIsNone(model_outputs.past_key_values)
        self.assertTrue(model.training)

    def test_restore_default_torch_dtype_from_pretrained(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set default type to float32
        torch.set_default_dtype(torch.float32)

        # Mock injection point which is right after the call to `_set_default_torch_dtype`
        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_restore_default_torch_dtype_from_config(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set default type to float32
        torch.set_default_dtype(torch.float32)

        config = AutoConfig.from_pretrained(
            TINY_MISTRAL,
        )

        # Mock injection point which is right after the call to `_set_default_torch_dtype`
        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                config.torch_dtype = torch.float16
                _ = AutoModelForCausalLM.from_config(
                    config,
                )
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_unknown_quantization_config(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config = BertConfig(
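The mocking pattern used in both new tests is worth noting: keep a reference to the real classmethod, then patch it with a `side_effect` wrapper that calls the original and raises, which injects a failure at an exact point in the loading path without re-implementing the method. The same technique on a toy class (all names here are illustrative):

from unittest import mock


class Loader:
    @classmethod
    def prepare(cls, value):
        return value * 2


# keep a handle on the real (already bound) classmethod before patching it
original_prepare = Loader.prepare


def prepare_then_fail(*args, **kwargs):
    original_prepare(*args, **kwargs)  # run the real behaviour first...
    raise RuntimeError("injected failure right after prepare()")  # ...then fail


with mock.patch.object(Loader, "prepare", side_effect=prepare_then_fail):
    try:
        Loader.prepare(3)
    except RuntimeError:
        print("failure injected exactly where we wanted it")

print(Loader.prepare(3))  # 6 -- the real classmethod is back once the patch exits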