diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4e27421574a..2c28ec8f1e5 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1252,13 +1252,13 @@ def _get_torch_dtype(
             for key, curr_dtype in torch_dtype.items():
                 if hasattr(config, key):
                     value = getattr(config, key)
+                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
                     value.torch_dtype = curr_dtype
             # main torch dtype for modules that aren't part of any sub-config
             torch_dtype = torch_dtype.get("")
+            torch_dtype = torch_dtype if not isinstance(torch_dtype, str) else getattr(torch, torch_dtype)
             config.torch_dtype = torch_dtype
-            if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
-                torch_dtype = getattr(torch, torch_dtype)
-            elif torch_dtype is None:
+            if torch_dtype is None:
                 torch_dtype = torch.float32
         else:
             raise ValueError(
@@ -1269,7 +1269,7 @@
         dtype_orig = cls._set_default_torch_dtype(torch_dtype)
     else:
         # set fp32 as the default dtype for BC
-        default_dtype = str(torch.get_default_dtype()).split(".")[-1]
+        default_dtype = torch.get_default_dtype()
         config.torch_dtype = default_dtype
         for key in config.sub_configs.keys():
             value = getattr(config, key)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 71a400579f6..96cbd4f77bb 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -482,9 +482,11 @@ class ModelUtilsTest(TestCasePlus):
         # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
         self.assertEqual(model.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
         self.assertEqual(model.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
@@ -495,14 +497,22 @@ class ModelUtilsTest(TestCasePlus):
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
         Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
+        # Load without dtype specified
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
+
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set torch_dtype as a dict for each sub-config
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -511,6 +521,7 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set the values as torch.dtype (not str)
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -519,6 +530,7 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set the values in configs directly and pass it to `from_pretrained`
         config = copy.deepcopy(model.config)
@@ -529,6 +541,7 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
         LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
@@ -536,6 +549,7 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):