Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Fix from_pretrained with corrupted state_dict (#12939)
* Fix from_pretrained with corrupted state_dict
* Adapt test
* Use better checkpoint
* Style
* Clean up
parent a28da4c490
commit d4c834d2e0
@@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         add_prefix = has_prefix_module and not expects_prefix_module

         if remove_prefix:
+            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
             expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
         elif add_prefix:
             expected_keys = [".".join([prefix, s]) for s in expected_keys]
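The hunk above only records which expected keys never carried the base-model prefix; the existing key rewriting is unchanged. A standalone illustration of what the two comprehensions do, with made-up key names that are not taken from the commit:

prefix = "bert"  # plays the role of cls.base_model_prefix
expected_keys = ["bert.embeddings.word_embeddings.weight", "classifier.weight"]

# Keys of the head model that never carried the base-model prefix (e.g. the classifier head).
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
# -> ["classifier.weight"]

# Strip the prefix from the remaining keys so they line up with a base-model checkpoint.
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
# -> ["embeddings.word_embeddings.weight", "classifier.weight"]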
@@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             start_prefix = cls.base_model_prefix + "."
         if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)
+            if any(key in expected_keys_not_prefixed for key in loaded_keys):
+                raise ValueError(
+                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
+                    "properly saved?"
+                )

         load(model_to_load, prefix=start_prefix)

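For context, a minimal sketch of the situation the new guard catches: a saved checkpoint whose base-model weights lost their "bert." prefix while head weights such as "classifier.weight" are still present. The tiny config, the local directory name, and the pytorch_model.bin filename (pre-safetensors versions of transformers) are assumptions for illustration, not taken from the commit.

import os

import torch
from transformers import BertConfig, BertForSequenceClassification

# A deliberately tiny model so the example runs quickly.
config = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
model = BertForSequenceClassification(config)

save_dir = "corrupted-checkpoint"  # hypothetical local path
model.save_pretrained(save_dir)

# Corrupt the saved weights: strip the "bert." prefix from the base-model keys
# while keeping the classification-head keys.
weights_path = os.path.join(save_dir, "pytorch_model.bin")
state_dict = torch.load(weights_path)
corrupted = {k[len("bert."):] if k.startswith("bert.") else k: v for k, v in state_dict.items()}
torch.save(corrupted, weights_path)

# Before this fix, such keys were silently loaded into the wrong submodule; with the
# new check, from_pretrained raises the "state dictionary ... is corrupted" error.
try:
    BertForSequenceClassification.from_pretrained(save_dir)
except ValueError as err:
    print(err)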
@@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.memory_inference_result)

     def test_inference_no_configs_only_pretrain(self):
-        MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
+        MODEL_ID = "sgugger/tiny-distilbert-classification"
         benchmark_args = PyTorchBenchmarkArguments(
             models=[MODEL_ID],
             training=False,
@@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.memory_inference_result)

     def test_inference_no_configs_only_pretrain(self):
-        MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
+        MODEL_ID = "sgugger/tiny-distilbert-classification"
         benchmark_args = TensorFlowBenchmarkArguments(
             models=[MODEL_ID],
             training=False,
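Both benchmark tests only swap in the new tiny checkpoint. For reference, a hedged sketch of how these arguments are typically driven outside the test suite; the batch size and sequence length values below are arbitrary choices, not part of the commit.

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

benchmark_args = PyTorchBenchmarkArguments(
    models=["sgugger/tiny-distilbert-classification"],
    training=False,
    inference=True,
    batch_sizes=[1],         # small values so the run finishes quickly
    sequence_lengths=[8],
    multi_process=False,
)
results = PyTorchBenchmark(benchmark_args).run()
print(results.time_inference_result)
print(results.memory_inference_result)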
@@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin

 class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "zero-shot-classification"
-    small_models = [
-        "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
-    ]  # Models tested without the @slow decorator
+    small_models = ["sgugger/tiny-distilbert-classification"]  # Models tested without the @slow decorator
     large_models = ["roberta-large-mnli"]  # Models tested with the @slow decorator
     valid_inputs = [
         {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
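The pipeline test likewise only replaces the small checkpoint. A short usage sketch matching the test's valid_inputs; the tiny checkpoint is a random-weight classifier, so the scores are meaningless and this only exercises the code path.

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="sgugger/tiny-distilbert-classification")
output = classifier("Who are you voting for in 2020?", candidate_labels="politics")
print(output)  # dict with "sequence", "labels", and "scores"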