Fix from_pretrained with corrupted state_dict (#12939)

* Fix from_pretrained with corrupted state_dict

* Adapt test

* Use better checkpoint

* Style

* Clean up
This commit is contained in:
Sylvain Gugger 2021-08-04 11:48:39 +02:00 committed by GitHub
parent a28da4c490
commit d4c834d2e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 5 deletions

View File

@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_prefix = has_prefix_module and not expects_prefix_module
if remove_prefix:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
elif add_prefix:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError(
"The state dictionary of the model you are training to load is corrupted. Are you sure it was "
"properly saved?"
)
load(model_to_load, prefix=start_prefix)

View File

@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_only_pretrain(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
MODEL_ID = "sgugger/tiny-distilbert-classification"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=False,

View File

@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_only_pretrain(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
MODEL_ID = "sgugger/tiny-distilbert-classification"
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,

View File

@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "zero-shot-classification"
small_models = [
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
] # Models tested without the @slow decorator
small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
valid_inputs = [
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},