Fix: loading DBRX back from saved path (#35728)
* fix dtype as dict for some models + add test
* add comment in tests
This commit is contained in:
parent 3613f568cd
commit b764c20b09
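Context for the change: after loading a DBRX model with an explicit torch_dtype and saving it, the serialized attn_config / ffn_config dicts gain a torch_dtype entry that the DBRX sub-configs used to reject as an unknown kwarg, so the checkpoint could not be loaded back from the saved path. A minimal reproduction sketch, with a placeholder checkpoint id and dtype:

from transformers import DbrxForCausalLM

# "some-org/some-dbrx-checkpoint" is a placeholder id, not a real repo.
model = DbrxForCausalLM.from_pretrained("some-org/some-dbrx-checkpoint", torch_dtype="bfloat16")
model.save_pretrained("./dbrx-saved")

# Before this fix, reloading from the saved path could fail with an
# unknown-kwargs ValueError, because the serialized attn_config / ffn_config
# dicts carried a `torch_dtype` entry that the DBRX sub-configs rejected.
model = DbrxForCausalLM.from_pretrained("./dbrx-saved")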
@@ -4037,7 +4037,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         sub_config = getattr(config, sub_config_key)
                         sub_config.torch_dtype = torch_dtype
             elif isinstance(torch_dtype, torch.dtype):
-                pass
+                for sub_config_key in config.sub_configs.keys():
+                    sub_config = getattr(config, sub_config_key)
+                    sub_config.torch_dtype = torch_dtype
             elif isinstance(torch_dtype, dict):
                 for key, curr_dtype in torch_dtype.items():
                     if hasattr(config, key):
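This hunk replaces the `pass` branch so that a plain `torch.dtype` is now propagated to every sub-config of a composite config, matching what the string and dict forms already did in the surrounding context. A standalone sketch of the resulting dispatch, for illustration only (not the actual `from_pretrained` internals; `config.sub_configs` is the name-to-sub-config mapping used in the hunk above):

import torch

def propagate_torch_dtype(config, torch_dtype):
    """Illustrative sketch of how the dtype argument reaches sub-configs."""
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        torch_dtype = getattr(torch, torch_dtype)  # e.g. "float16" -> torch.float16
    if isinstance(torch_dtype, torch.dtype):
        # New behaviour: a single dtype is mirrored onto every sub-config.
        for sub_config_key in config.sub_configs.keys():
            getattr(config, sub_config_key).torch_dtype = torch_dtype
    elif isinstance(torch_dtype, dict):
        # Dict form: assign a possibly different dtype per sub-config attribute.
        for key, curr_dtype in torch_dtype.items():
            if hasattr(config, key):
                getattr(config, key).torch_dtype = curr_dtype
    return torch_dtype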
@@ -57,7 +57,7 @@ class DbrxAttentionConfig(PretrainedConfig):
         self.kv_n_heads = kv_n_heads
         self.rope_theta = rope_theta
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
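Why `torch_dtype` joins the pop list: `DbrxConfig` rebuilds its sub-configs from the saved nested dicts, and `DbrxAttentionConfig` raises on any kwarg left over after this cleanup. With the modeling_utils change above, a saved DBRX config's nested dicts may now carry `torch_dtype`, so it has to be tolerated here; `DbrxFFNConfig` in the next hunk gets the identical treatment. A hedged sketch of the path that used to fail (field values are placeholders, not taken from a real checkpoint):

from transformers.models.dbrx.configuration_dbrx import DbrxAttentionConfig

# Shape of the nested attn_config dict a saved DBRX config can now contain.
saved_attn_config = {
    "attn_pdrop": 0.0,
    "clip_qkv": 8.0,
    "kv_n_heads": 8,
    "rope_theta": 500000.0,
    "torch_dtype": "bfloat16",  # written into the saved config when a dtype is propagated
}

# Before this commit, the unknown-kwargs check at the end of __init__ raised a
# ValueError on the unexpected "torch_dtype" key; now the key is popped along
# with "model_type", "transformers_version", "_commit_hash", etc.
attn_config = DbrxAttentionConfig(**saved_attn_config)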
@@ -109,7 +109,7 @@ class DbrxFFNConfig(PretrainedConfig):
         self.moe_loss_weight = moe_loss_weight
         self.moe_normalize_expert_weights = moe_normalize_expert_weights
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
@@ -331,6 +331,12 @@ class ModelTesterMixin:
                 with torch.no_grad():
                     second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
 
+                # Save and load second time because `from_pretrained` adds a bunch of new config fields
+                # so we need to make sure those fields can be loaded back after saving
+                # Simply init as `model(config)` doesn't add those fields
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+
                 if isinstance(first, tuple) and isinstance(second, tuple):
                     for tensor1, tensor2 in zip(first, second):
                         check_save_load(tensor1, tensor2)
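For context, `check_save_load` compares the forward-pass outputs captured before and after the save/load round trips. The helper itself is not part of this hunk, so the following is only a rough standalone sketch of that kind of comparison, with an assumed tolerance:

import numpy as np
import torch

def outputs_match(out1: torch.Tensor, out2: torch.Tensor, atol: float = 1e-5) -> bool:
    # Elementwise comparison of two forward-pass outputs, ignoring NaNs;
    # the 1e-5 tolerance is an assumption, not necessarily the test suite's value.
    a = np.nan_to_num(out1.detach().cpu().numpy())
    b = np.nan_to_num(out2.detach().cpu().numpy())
    return float(np.max(np.abs(a - b))) <= atol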
@@ -466,13 +466,14 @@ class ModelUtilsTest(TestCasePlus):
     def test_model_from_config_torch_dtype_composite(self):
         """
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
+        Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
 
-        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
 
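Beyond the string and `torch.dtype` forms checked above, the docstring's dict form maps sub-config attribute names to dtypes. A usage sketch, assuming the same tiny Llava checkpoint (`TINY_LLAVA` is the constant defined earlier in this test module) and LlavaConfig's `text_config` / `vision_config` attribute names:

import torch
from transformers import LlavaForConditionalGeneration

# torch_dtype as a per-sub-config dict: keys are the composite config's
# sub-config attribute names; TINY_LLAVA is the test module's tiny checkpoint id.
model = LlavaForConditionalGeneration.from_pretrained(
    TINY_LLAVA,
    torch_dtype={"text_config": "float32", "vision_config": "float16"},
)
assert model.language_model.dtype == torch.float32
assert model.vision_tower.dtype == torch.float16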