Regression test for pegasus bugfix (#6606)
commit 5bf4465e6c (parent 86c07e634f)
src/transformers/configuration_pegasus.py
@@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
 
 logger = logging.getLogger(__name__)
 
+# These config values do not vary between checkpoints
 DEFAULTS = dict(
     vocab_size=96103,
     max_position_embeddings=512,
@@ -46,6 +47,47 @@ DEFAULTS = dict(
     num_beams=8,
     activation_function="relu",
 )
+# Config values that vary between checkpoints: for testing and conversion
+max_gen_length = {
+    # See appendix C of paper
+    "xsum": 64,
+    "cnn_dailymail": 128,
+    "newsroom": 128,
+    "wikihow": 256,
+    "multi_news": 256,
+    "reddit_tifu": 128,
+    "big_patent": 256,
+    "arxiv": 256,
+    "pubmed": 256,
+    "gigaword": 32,
+    "aeslc": 32,
+    "billsum": 256,
+    "large": 256,  # @sshleifer chose arbitrarily
+}
+max_model_length = {
+    "xsum": 512,
+    "cnn_dailymail": 1024,
+    "newsroom": 512,
+    "wikihow": 512,
+    "multi_news": 1024,
+    "reddit_tifu": 512,
+    "big_patent": 1024,
+    "arxiv": 1024,
+    "pubmed": 1024,
+    "gigaword": 128,
+    "aeslc": 512,
+    "billsum": 1024,
+    "large": 1024,
+}
+expected_alpha = {
+    "multinews": 0.9,
+    "wikihow": 0.6,
+    "reddit_tifu": 0.6,
+    "big_patent": 0.7,
+    "gigaword": 0.6,
+    "aeslc": 0.6,
+    "billsum": 0.6,
+}  # otherwise 0.8
 
 
 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
@@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
     """
     model_type = "pegasus"
     # The implementation of the config object is in BartConfig
-
-    @property
-    def default_config_parameters(self):
-        return DEFAULTS
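The new module-level tables make the per-checkpoint settings importable by both the conversion script and the regression test. A minimal sketch of how they compose into config overrides, mirroring the conversion hunk further down ("xsum" is just one key from the dicts; the printed values follow directly from them):

# Sketch: build per-checkpoint overrides from the tables added above.
from transformers.configuration_pegasus import expected_alpha, max_gen_length, max_model_length

dataset = "xsum"
overrides = dict(
    max_length=max_gen_length[dataset],               # generation stops at 64 tokens
    length_penalty=expected_alpha.get(dataset, 0.8),  # 0.8 fallback, per the comment above
    max_position_embeddings=max_model_length[dataset],
)
print(overrides)  # {'max_length': 64, 'length_penalty': 0.8, 'max_position_embeddings': 512}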
src/transformers/convert_pegasus_tf_to_pytorch.py
@@ -22,7 +22,7 @@ import torch
 from tqdm import tqdm
 
 from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.configuration_pegasus import DEFAULTS
+from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
 
 
 PATTERNS = [
@@ -52,47 +52,7 @@ def rename_state_dict_key(k):
 
 
 # See appendix C of paper for all hyperparams
-max_gen_length = {
-    # See appendix C of paper
-    "xsum": 64,
-    "cnn_dailymail": 128,
-    "newsroom": 128,
-    "wikihow": 256,
-    "multi_news": 256,
-    "reddit_tifu": 128,
-    "big_patent": 256,
-    "arxiv": 256,
-    "pubmed": 256,
-    "gigaword": 32,
-    "aeslc": 32,
-    "billsum": 256,
-    "large": 256,  # @sshleifer chose arbitrarily
-}
-max_model_length = {
-    "xsum": 512,
-    "cnn_dailymail": 1024,
-    "newsroom": 512,
-    "wikihow": 512,
-    "multi_news": 1024,
-    "reddit_tifu": 512,
-    "big_patent": 1024,
-    "arxiv": 1024,
-    "pubmed": 1024,
-    "gigaword": 128,
-    "aeslc": 512,
-    "billsum": 1024,
-    "large": 1024,
-}
 
-expected_alpha = {
-    "multinews": 0.9,
-    "wikihow": 0.6,
-    "reddit_tifu": 0.6,
-    "big_patent": 0.7,
-    "gigaword": 0.6,
-    "aeslc": 0.6,
-    "billsum": 0.6,
-}  # otherwise 0.8
 # TODO(SS): one constant
 
 
@@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
 
     # convert model
     tf_weights = get_tf_weights_as_numpy(ckpt_path)
-    cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
+    cfg_updates = dict(
+        max_length=max_gen_length[dataset],
+        length_penalty=expected_alpha.get(dataset, 0.8),
+        max_position_embeddings=desired_max_model_length,
+    )
    torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
    torch_model.save_pretrained(save_dir)
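The added max_position_embeddings override is the substance of the fix: Pegasus uses a fixed-size position embedding table, and feeding in more positions than the table has rows raises the IndexError reported in https://github.com/huggingface/transformers/issues/6599. A plain-PyTorch illustration of the failure mode (the names are ours, not the library's; the real table lives inside the model):

import torch
import torch.nn as nn

# 512-row position table, matching the old hard-coded default
embed_positions = nn.Embedding(512, 1024)

try:
    embed_positions(torch.arange(513))  # asks for positions 0..512; table only has 0..511
except IndexError as err:
    print("IndexError:", err)  # "index out of range in self" on CPU

desired_max_model_length is defined outside this hunk; presumably it is looked up from the max_model_length table, but that line is not part of the diff.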
src/transformers/modeling_pegasus.py
@@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
 class PegasusForConditionalGeneration(BartForConditionalGeneration):
     config_class = PegasusConfig
+    authorized_missing_keys = [
+        r"final_logits_bias",
+        r"encoder\.version",
+        r"decoder\.version",
+        r"model.encoder.embed_positions",
+        "model.decoder.embed_positions",
+    ]
     r"""
     Pytorch version of google's pegasus model for summarization.
     Model API is identical to BartForConditionalGeneration.
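authorized_missing_keys lets from_pretrained load a checkpoint that lacks these entries without warning the user; static position embeddings in particular can be regenerated at load time rather than stored. A rough sketch of the pattern matching — our regex-based approximation, not the library's exact loading code:

import re

authorized_missing_keys = [
    r"final_logits_bias",
    r"encoder\.version",
    r"decoder\.version",
    r"model.encoder.embed_positions",
    "model.decoder.embed_positions",
]

missing = ["model.decoder.embed_positions.weight", "final_logits_bias"]
# Keys matching any authorized pattern are dropped from the warning list
reported = [k for k in missing if not any(re.search(p, k) for p in authorized_missing_keys)]
print(reported)  # []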
tests/test_modeling_pegasus.py
@@ -1,6 +1,7 @@
 import unittest
 
-from transformers import AutoConfig, is_torch_available
+from transformers import AutoConfig, AutoTokenizer, is_torch_available
+from transformers.configuration_pegasus import max_gen_length, max_model_length
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
 
 class PegasusConfigTests(unittest.TestCase):
     def test_all_config_max_lengths(self):
-        expected_max_length = {
-            # See appendix C of paper
-            "xsum": 64,
-            "cnn_dailymail": 128,
-            "newsroom": 128,
-            "wikihow": 256,
-            "multi_news": 256,
-            "reddit_tifu": 128,
-            "big_patent": 256,
-            "arxiv": 256,
-            "pubmed": 256,
-            "gigaword": 32,
-            "aeslc": 32,
-            "billsum": 256,
-        }
         failures = []
         pegasus_prefix = "google/pegasus"
-        for dataset, max_len in expected_max_length.items():
+        for dataset, max_len in max_gen_length.items():
             mname = f"{pegasus_prefix}-{dataset}"
             cfg = AutoConfig.from_pretrained(mname)
 
             if cfg.max_length != max_len:
                 failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected 18,214")
 
+            if cfg.max_position_embeddings < max_model_length[dataset]:
+                # otherwise you get IndexError for e.g. position 513
+                # see https://github.com/huggingface/transformers/issues/6599
+                failures.append(
+                    f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
+                )
+
+            tokenizer = AutoTokenizer.from_pretrained(mname)
+            if max_model_length[dataset] != tokenizer.model_max_length:
+                failures.append(
+                    f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
+                )
+
         if failures == []:
             return
         # error
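Taken together, the regression test asserts three invariants for every released checkpoint: config.max_length matches appendix C, config.max_position_embeddings is large enough for the dataset, and the tokenizer's limit agrees with the model's. A condensed version of the same check for a single checkpoint (requires hub access; the model name is one of the real google/pegasus-* checkpoints, the expected values come from the tables above):

from transformers import AutoConfig, AutoTokenizer

mname = "google/pegasus-cnn_dailymail"
cfg = AutoConfig.from_pretrained(mname)
tok = AutoTokenizer.from_pretrained(mname)

# Inputs longer than the position table trigger the issue #6599 IndexError,
# so the table must cover everything the tokenizer will let through.
assert cfg.max_position_embeddings >= tok.model_max_length
assert cfg.max_length == 128  # per max_gen_length["cnn_dailymail"]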