Nail in edge case of torch dtype being overridden permanently in the case of an error (#35845)

* Nail in edge case of torch dtype

* Rm unused func

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Refactor tests to only mock what we need, don't introduce injection functions

* SetUp/TearDown

* Do super

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Zach Mueller 2025-02-06 09:05:23 -05:00 committed by GitHub
parent e3458af726
commit 1ce0e2992e
2 changed files with 91 additions and 0 deletions
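The edge case this commit nails down, as a hedged sketch ("broken-checkpoint" is a hypothetical path; any failure between the dtype switch and its restoration triggers the bug):

```python
import torch
from transformers import AutoModelForCausalLM

assert torch.get_default_dtype() == torch.float32

try:
    # If loading fails after the global default dtype has been flipped to
    # float16 but before it is flipped back, the override used to persist.
    AutoModelForCausalLM.from_pretrained("broken-checkpoint", torch_dtype=torch.float16)
except Exception:
    pass

# Before this commit this assertion could fail; with the decorator below it holds.
assert torch.get_default_dtype() == torch.float32
```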

src/transformers/modeling_utils.py

@@ -246,6 +246,25 @@ def set_zero3_state():
    _is_ds_init_called = False


def restore_default_torch_dtype(func):
    """
    Decorator to restore the default torch dtype at the end of the function. Serves
    as a backup in case calling the function raises an error after the function has
    changed the default dtype but before it could restore it.
    """

    @wraps(func)
    def _wrapper(*args, **kwargs):
        old_dtype = torch.get_default_dtype()
        try:
            return func(*args, **kwargs)
        finally:
            torch.set_default_dtype(old_dtype)

    return _wrapper


def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
    try:
        return next(parameter.parameters()).device
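To see the decorator's guarantee in isolation, a minimal self-contained sketch (`flaky_loader` is a made-up stand-in for a model loader; the decorator body is copied from the hunk above):

```python
import torch
from functools import wraps


def restore_default_torch_dtype(func):
    @wraps(func)
    def _wrapper(*args, **kwargs):
        old_dtype = torch.get_default_dtype()
        try:
            return func(*args, **kwargs)
        finally:
            # runs on success *and* on error, so the default can never leak
            torch.set_default_dtype(old_dtype)

    return _wrapper


@restore_default_torch_dtype
def flaky_loader():
    torch.set_default_dtype(torch.float16)  # mimic _set_default_torch_dtype
    raise RuntimeError("failure mid-load")


try:
    flaky_loader()
except RuntimeError:
    pass
print(torch.get_default_dtype())  # torch.float32: the finally block restored it
```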
@@ -1407,6 +1426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        self.model_tags.append(tag)

    @classmethod
    @restore_default_torch_dtype
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

@@ -3142,6 +3162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        return super().float(*args)

    @classmethod
    @restore_default_torch_dtype
    def from_pretrained(
        cls: Type[SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
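Both entry points need the decorator because they funnel through `_set_default_torch_dtype`, which flips the process-wide default and relies on the caller to flip it back. A simplified, runnable sketch of the unprotected control flow (the helper name mirrors the real one, but this is an illustration, not the library code):

```python
import torch


def _set_default_torch_dtype(dtype):
    # flip the global default and hand the previous one back to the caller
    dtype_orig = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    return dtype_orig


def unprotected_load():
    dtype_orig = _set_default_torch_dtype(torch.float16)
    raise RuntimeError("simulated failure while materializing weights")
    torch.set_default_dtype(dtype_orig)  # unreachable: exactly the leak being fixed


try:
    unprotected_load()
except RuntimeError:
    pass
assert torch.get_default_dtype() == torch.float16  # the override leaked
torch.set_default_dtype(torch.float32)  # clean up again
```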

tests/utils/test_modeling_utils.py

@@ -39,6 +39,7 @@ from transformers import (
    AutoModelForSequenceClassification,
    DynamicCache,
    LlavaForConditionalGeneration,
    MistralForCausalLM,
    OwlViTForObjectDetection,
    PretrainedConfig,
    is_torch_available,
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):
@require_torch
class ModelUtilsTest(TestCasePlus):
    def setUp(self):
        self.old_dtype = torch.get_default_dtype()
        super().setUp()

    def tearDown(self):
        torch.set_default_dtype(self.old_dtype)
        super().tearDown()

    @slow
    def test_model_from_pretrained(self):
        model_name = "google-bert/bert-base-uncased"
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
        self.assertIsNone(model_outputs.past_key_values)
        self.assertTrue(model.training)

    def test_restore_default_torch_dtype_from_pretrained(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set the default dtype to float32
        torch.set_default_dtype(torch.float32)

        # Mock injection point, which is right after the call to `_set_default_torch_dtype`
        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)

        # the default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_restore_default_torch_dtype_from_config(self):
        """
        Tests that the default torch dtype is restored
        when an error happens while instantiating a model from a config.
        """
        old_dtype = torch.get_default_dtype()
        # set the default dtype to float32
        torch.set_default_dtype(torch.float32)

        config = AutoConfig.from_pretrained(
            TINY_MISTRAL,
        )

        # Mock injection point, which is right after the call to `_set_default_torch_dtype`
        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                config.torch_dtype = torch.float16
                _ = AutoModelForCausalLM.from_config(
                    config,
                )

        # the default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_unknown_quantization_config(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config = BertConfig(
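For reference, the injection pattern the tests rely on, reduced to a standalone sketch: keep a reference to the real method, then patch it with a `side_effect` that calls the original before raising, so the error lands exactly one step after `_set_default_torch_dtype` succeeds (`Loader` and `explode` are made-up names):

```python
from unittest import mock


class Loader:
    @classmethod
    def _set_default_torch_dtype(cls, dtype):
        print(f"default dtype set to {dtype}")


original = Loader._set_default_torch_dtype  # bound classmethod, safe to call directly


def explode(*args, **kwargs):
    original(*args, **kwargs)  # keep the real behaviour...
    raise RuntimeError         # ...then fail immediately after it


with mock.patch.object(Loader, "_set_default_torch_dtype", side_effect=explode):
    try:
        Loader._set_default_torch_dtype("float16")
    except RuntimeError:
        print("error injected right after the real call")
```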