[BART] add bart-large-xsum weights (#3422)

Sam Shleifer 2020-03-29 10:51:13 -04:00 committed by GitHub
parent 601ac5b1dc
commit f6a23d1911
5 changed files with 94 additions and 41 deletions

src/transformers/configuration_bart.py

@@ -26,6 +26,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
     "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
     "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
+    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
 }
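Note: with this entry in the config archive map (plus the matching weights and tokenizer entries below), the new shortcut name resolves through the ordinary `from_pretrained` machinery. A minimal sketch, assuming the S3 object above is live; transformers of this era used bare shortcut names with no `facebook/` prefix:

```python
from transformers import BartConfig

# Resolves "bart-large-xsum" via BART_PRETRAINED_CONFIG_ARCHIVE_MAP and
# downloads the config.json registered above.
config = BartConfig.from_pretrained("bart-large-xsum")

# The hosted config ships task-specific generation defaults
# (checked by test_xsum_config_generation_params later in this commit).
print(config.num_beams, config.do_sample, config.early_stopping, config.length_penalty)
```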

src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py

@@ -17,6 +17,7 @@
 import argparse
 import logging
+import os
 from pathlib import Path

 import fairseq
@@ -30,10 +31,11 @@ from transformers import (
     BartModel,
     BartTokenizer,
 )
+from transformers.modeling_bart import _make_linear_from_emb

-FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
+FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
 extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
     raise Exception("requires fairseq >= 0.9.0")
@@ -57,62 +59,79 @@ def rename_key(dct, old, new):
     dct[new] = val


-def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
+def load_xsum_checkpoint(checkpoint_path):
+    """Checkpoint path should end in model.pt"""
+    sd = torch.load(checkpoint_path, map_location="cpu")
+    hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
+    hub_interface.model.load_state_dict(sd["model"])
+    return hub_interface
+
+
+@torch.no_grad()
+def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
     """
     Copy/paste/tweak model's weights to our BERT structure.
     """
-    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
-    bart.eval()  # disable dropout
+    if not os.path.exists(checkpoint_path):
+        bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
+    else:
+        bart = load_xsum_checkpoint(checkpoint_path)
+
     bart.model.upgrade_state_dict(bart.model.state_dict())
-    hf_model_name = checkpoint_path.replace(".", "-")
-    config = BartConfig.from_pretrained(hf_model_name)
+    if hf_checkpoint_name is None:
+        hf_checkpoint_name = checkpoint_path.replace(".", "-")
+    config = BartConfig.from_pretrained(hf_checkpoint_name)
     tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
-    tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
+    tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
     assert torch.eq(tokens, tokens2).all()
-    if checkpoint_path in ["bart.large", "bart.large.cnn"]:
-        state_dict = bart.model.state_dict()
-        for k in IGNORE_KEYS:
-            state_dict.pop(k, None)
-        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
-        model = BartModel(config)
-        their_output = bart.extract_features(tokens)
-    else:  # MNLI Case
+    if checkpoint_path == "bart.large.mnli":
         state_dict = bart.state_dict()
-        for k in IGNORE_KEYS:
-            state_dict.pop(k, None)
+        remove_ignore_keys_(state_dict)
         state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
         for src, dest in rename_keys:
             rename_key(state_dict, src, dest)
-        model = BartForSequenceClassification(config)
-        their_output = bart.predict("mnli", tokens, return_logits=True)
+        model = BartForSequenceClassification(config).eval()
+        model.load_state_dict(state_dict)
+        fairseq_output = bart.predict("mnli", tokens, return_logits=True)
+        new_model_outputs = model(tokens)[0]  # logits
+    else:  # no classification heads to worry about
+        state_dict = bart.model.state_dict()
+        remove_ignore_keys_(state_dict)
+        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
+        fairseq_output = bart.extract_features(tokens)
+        if hf_checkpoint_name == "bart-large":
+            model = BartModel(config).eval()
+            model.load_state_dict(state_dict)
+            new_model_outputs = model(tokens).model[0]
+        else:
+            model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
+            model.model.load_state_dict(state_dict)
+            if hasattr(model, "lm_head"):
+                model.lm_head = _make_linear_from_emb(model.model.shared)
+            new_model_outputs = model.model(tokens)[0]
-    # Load state dict
-    model.load_state_dict(state_dict)
-    model.eval()
     # Check results
-
-    if checkpoint_path == "bart.large.cnn":
-        model = BartForConditionalGeneration(config, base_model=model)
-        assert "lm_head.weight" in model.state_dict()
-        assert model.lm_head.out_features == config.max_position_embeddings
-        model.eval()
-        our_outputs = model.model(tokens)[0]
-    else:
-        our_outputs = model(tokens)[0]
-
-    assert their_output.shape == our_outputs.shape
-    assert (their_output == our_outputs).all().item()
+    assert fairseq_output.shape == new_model_outputs.shape
+    assert (fairseq_output == new_model_outputs).all().item()
     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
     model.save_pretrained(pytorch_dump_folder_path)


+def remove_ignore_keys_(state_dict):
+    for k in IGNORE_KEYS:
+        state_dict.pop(k, None)
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
-    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
-    args = parser.parse_args()
-    convert_bart_checkpoint(
-        args.fairseq_path, args.pytorch_dump_folder_path,
-    )
+    parser.add_argument(
+        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
+    )
+    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+    parser.add_argument(
+        "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
+    )
+    args = parser.parse_args()
+    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
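Note: the updated entry point now handles both hub names and local files. A sketch of both uses, assuming a fairseq XSum checkpoint has been downloaded to `bart_xsum/model.pt` and the script is run from its own directory:

```python
# Sketch only: the import path assumes you run from the directory containing
# convert_bart_original_pytorch_checkpoint_to_pytorch.py.
from convert_bart_original_pytorch_checkpoint_to_pytorch import convert_bart_checkpoint

# Hub name: os.path.exists() is False, so the name is forwarded to torch.hub.load.
convert_bart_checkpoint("bart.large.cnn", "./bart-large-cnn")

# Local file: os.path.exists() is True, so load_xsum_checkpoint() loads the
# state dict into a bart.large.cnn hub interface; hf_checkpoint_name
# (the --hf_config flag) selects the matching HuggingFace config.
convert_bart_checkpoint("bart_xsum/model.pt", "./bart-large-xsum", hf_checkpoint_name="bart-large-xsum")
```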

src/transformers/modeling_bart.py

@@ -34,6 +34,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
     "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
     "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
     "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
+    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
 }

 BART_START_DOCSTRING = r"""
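Note: with the weights registered, end-to-end summarization runs through `generate()`. A minimal sketch mirroring the integration test added below (the `generate()` keyword set is the one that test actually passes, so it is known-good for this version of the API):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").eval()
tok = BartTokenizer.from_pretrained("bart-large-xsum")

article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds ..."
dct = tok.batch_encode_plus([article], max_length=1024, pad_to_max_length=True, return_tensors="pt")
with torch.no_grad():
    summary_ids = model.generate(
        input_ids=dct["input_ids"],
        attention_mask=dct["attention_mask"],
        num_beams=2,
        max_length=62,
        early_stopping=True,
        # In this version, generation starts from EOS; eos_token_ids is a list.
        decoder_start_token_id=model.config.eos_token_ids[0],
    )
print(tok.decode(summary_ids[0], skip_special_tokens=True))
```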

src/transformers/tokenization_bart.py

@@ -19,7 +19,7 @@ from .tokenization_roberta import RobertaTokenizer

 # vocab and merges same as roberta
 vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
-_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
+_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]

 class BartTokenizer(RobertaTokenizer):
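Note: because every BART checkpoint reuses roberta-large's vocab and merges (the two URLs above), listing the new name in `_all_bart_models` is the entire change. A quick sketch:

```python
from transformers import BartTokenizer

# Same vocab.json/merges.txt as "bart-large"; only the shortcut name differs.
tok = BartTokenizer.from_pretrained("bart-large-xsum")
ids = tok.encode("Hello world", return_tensors="pt")
print(ids, tok.decode(ids[0], skip_special_tokens=True))
```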

tests/test_modeling_bart.py

@@ -450,6 +450,38 @@ class BartModelIntegrationTests(unittest.TestCase):
             model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

+    @slow
+    def test_xsum_summarization_same_as_fairseq(self):
+        model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
+        tok = BartTokenizer.from_pretrained("bart-large")
+
+        PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
+        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
+        dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",)
+
+        hypotheses_batch = model.generate(
+            input_ids=dct["input_ids"].to(torch_device),
+            attention_mask=dct["attention_mask"].to(torch_device),
+            num_beams=2,
+            max_length=62,
+            min_length=11,
+            length_penalty=1.0,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+            decoder_start_token_id=model.config.eos_token_ids[0],
+        )
+
+        decoded = [
+            tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
+        ]
+        self.assertEqual(EXPECTED_SUMMARY, decoded[0])
+
+    def test_xsum_config_generation_params(self):
+        config = BartConfig.from_pretrained("bart-large-xsum")
+        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
+        config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
+        self.assertDictEqual(expected_params, config_params)
+
     @slow
     def test_cnn_summarization_same_as_fairseq(self):
         hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)