Mirror of https://github.com/huggingface/transformers.git

Fix weight loading issue (#14016)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

parent 74e6111ba7
commit a67d47b40c
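
For context, a minimal sketch of the round trip this commit repairs. The paths are placeholders and the failure mode is paraphrased from the comments the diff adds below; this is not code from the repository:

    from transformers import TFEncoderDecoderModel

    # Cross-load a PyTorch checkpoint into the TF implementation.
    model = TFEncoderDecoderModel.from_pretrained("path/to/pt_checkpoint", from_pt=True)

    # Before this fix, `from_pretrained` following `save_pretrained` could load
    # weights incorrectly for a model created this way.
    model.save_pretrained("path/to/tf_checkpoint")
    model = TFEncoderDecoderModel.from_pretrained("path/to/tf_checkpoint")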
@@ -15,6 +15,7 @@
 """ Classes to support TF Encoder-Decoder architectures """
 
 
+import tempfile
 from typing import Optional
 
 import tensorflow as tf
@@ -254,6 +255,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
         >>> # This is only for copying some specific attributes of this particular model.
         >>> model.config = _model.config
 
+        Example::
+
+            >>> from transformers import TFEncoderDecoderModel
+            >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
+
         """
 
         from_pt = kwargs.pop("from_pt", False)
@@ -369,6 +375,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
+            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+            if kwargs_encoder.get("from_pt", None):
+                del kwargs_encoder["from_pt"]
+                with tempfile.TemporaryDirectory() as tmp_dirname:
+                    encoder.save_pretrained(tmp_dirname)
+                    del encoder
+                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
+
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
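
Read in isolation, the block just added for the encoder (and mirrored for the decoder in the next hunk) amounts to the following pattern; `reload_via_tmpdir` is a name invented for this sketch, not a helper in the library:

    import tempfile

    from transformers import TFAutoModel

    def reload_via_tmpdir(model, *model_args, **kwargs):
        # Save the PT-converted model as a TF checkpoint and re-create it from
        # that checkpoint, so the rebuilt model carries the weight names a later
        # `save_pretrained`/`from_pretrained` round trip expects.
        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)
            del model
            return TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs)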
@@ -397,6 +411,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
+            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+            if kwargs_decoder.get("from_pt", None):
+                del kwargs_decoder["from_pt"]
+                with tempfile.TemporaryDirectory() as tmp_dirname:
+                    decoder.save_pretrained(tmp_dirname)
+                    del decoder
+                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
+
         # Make sure these two `tf.keras.Model`s have fixed names so `from_pretrained` can load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
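
The name check above exists because a `tf.keras.Model`'s `name` becomes the prefix of its variable names, and checkpoint matching keys on those prefixed names. A standalone illustration (not code from the repository):

    import tensorflow as tf

    # The layer's `name` prefixes its variables.
    dense = tf.keras.layers.Dense(4, name="encoder")
    dense.build((None, 8))
    print(dense.weights[0].name)  # e.g. "encoder/kernel:0"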
@@ -457,6 +457,14 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
 
         self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
 
+        # Test with the TF checkpoint
+        model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
+
+        output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
+        summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
+
 
 @require_tf
 class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@@ -785,6 +793,16 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
         max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
         self.assertAlmostEqual(max_diff, 0.0, places=3)
 
+        # Make sure `from_pretrained` following `save_pretrained` works and gives the same result
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            encoder_decoder_tf.save_pretrained(tmp_dirname)
+            encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
+
+            logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
+
+            max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
+            self.assertAlmostEqual(max_diff, 0.0, places=3)
+
         # TensorFlow => PyTorch
         with tempfile.TemporaryDirectory() as tmp_dirname:
             encoder_decoder_tf.save_pretrained(tmp_dirname)
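
The equivalence checks in this test all follow one pattern: compare two logits tensors by their maximum absolute difference and require it to vanish to three decimal places. A small illustrative helper (`max_abs_diff` is not part of the test suite):

    import numpy as np

    def max_abs_diff(a, b):
        # The tests assert this is ~0.0 to 3 places, i.e. agreement to ~1e-3.
        return np.max(np.abs(np.asarray(a) - np.asarray(b)))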