Fix generation config for empty state dict (#21630)
parent 317282927d
commit d4ba6e1a0e
@@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except OSError:
+            except (OSError, TypeError):
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
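For context, this hunk widens the exception because, when a model is built from an in-memory state_dict with no checkpoint path, the internal GenerationConfig.from_pretrained(None, ...) call fails with a TypeError (path handling on None) rather than an OSError, so the fallback to a config-derived generation config never ran. A minimal sketch of the failure mode (hedged; the exact error message depends on the transformers version):

```python
# Sketch: why `except OSError` alone was not enough. Passing None as the
# checkpoint path fails during path handling with TypeError, not OSError.
from transformers import GenerationConfig

try:
    GenerationConfig.from_pretrained(None)
except TypeError as err:
    print(f"TypeError, not OSError: {err}")
```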
@@ -325,6 +325,18 @@ class ModelTesterMixin:
         else:
             check_save_load(first, second)

+    def test_from_pretrained_no_checkpoint(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            new_model = model_class.from_pretrained(
+                pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+            )
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.equal(p1, p2))
+
     def test_save_load_keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
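The new mixin test runs the no-checkpoint path for every architecture, including generation-capable ones that exercise the TypeError branch above. A self-contained sketch of the same round trip, assuming a deliberately tiny GPT-2 config so it runs fast (any generation-capable class should behave the same):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny, randomly initialized model: nothing is read from or written to disk.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64)
model = GPT2LMHeadModel(config)
state_dict = model.state_dict()

# No checkpoint path. GPT2LMHeadModel can generate, so from_pretrained also
# tries to resolve a generation config; with the fix it falls back to one
# created from the model config instead of raising.
new_model = GPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
    assert torch.equal(p1, p2)
```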
@@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus):
             BertModel.from_pretrained(TINY_T5)
         self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)

-    def test_model_from_pretrained_no_checkpoint(self):
-        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
-        model = BertModel(config)
-        state_dict = model.state_dict()
-
-        new_model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict)
-        for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
-
     def test_model_from_config_torch_dtype(self):
         # test that the model can be instantiated with dtype of user's choice - as long as it's a
         # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the