diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 254e4c9fc46..9c99a6d770d 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4037,7 +4037,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                             sub_config = getattr(config, sub_config_key)
                             sub_config.torch_dtype = torch_dtype
                 elif isinstance(torch_dtype, torch.dtype):
-                    pass
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
                 elif isinstance(torch_dtype, dict):
                     for key, curr_dtype in torch_dtype.items():
                         if hasattr(config, key):
diff --git a/src/transformers/models/dbrx/configuration_dbrx.py b/src/transformers/models/dbrx/configuration_dbrx.py
index 7935b1d1beb..72df1fe335b 100644
--- a/src/transformers/models/dbrx/configuration_dbrx.py
+++ b/src/transformers/models/dbrx/configuration_dbrx.py
@@ -57,7 +57,7 @@ class DbrxAttentionConfig(PretrainedConfig):
         self.kv_n_heads = kv_n_heads
         self.rope_theta = rope_theta
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
@@ -109,7 +109,7 @@ class DbrxFFNConfig(PretrainedConfig):
         self.moe_loss_weight = moe_loss_weight
         self.moe_normalize_expert_weights = moe_normalize_expert_weights
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 98709ba3b84..1d52e828144 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -331,6 +331,12 @@ class ModelTesterMixin:
                 with torch.no_grad():
                     second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
 
+                # Save and load second time because `from_pretrained` adds a bunch of new config fields
+                # so we need to make sure those fields can be loaded back after saving
+                # Simply init as `model(config)` doesn't add those fields
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+
                 if isinstance(first, tuple) and isinstance(second, tuple):
                     for tensor1, tensor2 in zip(first, second):
                         check_save_load(tensor1, tensor2)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index dfc31100509..dd52927a250 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -466,13 +466,14 @@ class ModelUtilsTest(TestCasePlus):
 
     def test_model_from_config_torch_dtype_composite(self):
         """
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
+        Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
 
-        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)