diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 077e257a58c..3e019b24767 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -26,6 +26,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json", "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json", "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", + "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json", } diff --git a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py index e93a5b35c36..22fb047db7f 100644 --- a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py @@ -17,6 +17,7 @@ import argparse import logging +import os from pathlib import Path import fairseq @@ -30,10 +31,11 @@ from transformers import ( BartModel, BartTokenizer, ) +from transformers.modeling_bart import _make_linear_from_emb -FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"] - +FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"] +extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification} if version.parse(fairseq.__version__) < version.parse("0.9.0"): raise Exception("requires fairseq >= 0.9.0") @@ -57,62 +59,79 @@ def rename_key(dct, old, new): dct[new] = val -def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path): +def load_xsum_checkpoint(checkpoint_path): + """Checkpoint path should end in model.pt""" + sd = torch.load(checkpoint_path, map_location="cpu") + hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval() + hub_interface.model.load_state_dict(sd["model"]) + return hub_interface + + +@torch.no_grad() +def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None): """ Copy/paste/tweak model's weights to our BERT structure. """ - bart = torch.hub.load("pytorch/fairseq", checkpoint_path) - bart.eval() # disable dropout + if not os.path.exists(checkpoint_path): + bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval() + else: + bart = load_xsum_checkpoint(checkpoint_path) + bart.model.upgrade_state_dict(bart.model.state_dict()) - hf_model_name = checkpoint_path.replace(".", "-") - config = BartConfig.from_pretrained(hf_model_name) + if hf_checkpoint_name is None: + hf_checkpoint_name = checkpoint_path.replace(".", "-") + config = BartConfig.from_pretrained(hf_checkpoint_name) tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0) - tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0) + tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0) assert torch.eq(tokens, tokens2).all() - if checkpoint_path in ["bart.large", "bart.large.cnn"]: - state_dict = bart.model.state_dict() - for k in IGNORE_KEYS: - state_dict.pop(k, None) - state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] - model = BartModel(config) - their_output = bart.extract_features(tokens) - else: # MNLI Case + if checkpoint_path == "bart.large.mnli": state_dict = bart.state_dict() - for k in IGNORE_KEYS: - state_dict.pop(k, None) + remove_ignore_keys_(state_dict) state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"] for src, dest in rename_keys: rename_key(state_dict, src, dest) - model = BartForSequenceClassification(config) - their_output = bart.predict("mnli", tokens, return_logits=True) + model = BartForSequenceClassification(config).eval() + model.load_state_dict(state_dict) + fairseq_output = bart.predict("mnli", tokens, return_logits=True) + new_model_outputs = model(tokens)[0] # logits + else: # no classification heads to worry about + state_dict = bart.model.state_dict() + remove_ignore_keys_(state_dict) + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + fairseq_output = bart.extract_features(tokens) + if hf_checkpoint_name == "bart-large": + model = BartModel(config).eval() + model.load_state_dict(state_dict) + new_model_outputs = model(tokens).model[0] + else: + model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt + model.model.load_state_dict(state_dict) + if hasattr(model, "lm_head"): + model.lm_head = _make_linear_from_emb(model.model.shared) + new_model_outputs = model.model(tokens)[0] - # Load state dict - model.load_state_dict(state_dict) - model.eval() # Check results - - if checkpoint_path == "bart.large.cnn": - model = BartForConditionalGeneration(config, base_model=model) - assert "lm_head.weight" in model.state_dict() - assert model.lm_head.out_features == config.max_position_embeddings - model.eval() - our_outputs = model.model(tokens)[0] - else: - our_outputs = model(tokens)[0] - assert their_output.shape == our_outputs.shape - assert (their_output == our_outputs).all().item() + assert fairseq_output.shape == new_model_outputs.shape + assert (fairseq_output == new_model_outputs).all().item() Path(pytorch_dump_folder_path).mkdir(exist_ok=True) model.save_pretrained(pytorch_dump_folder_path) +def remove_ignore_keys_(state_dict): + for k in IGNORE_KEYS: + state_dict.pop(k, None) + + if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters - parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="") - - parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") - args = parser.parse_args() - convert_bart_checkpoint( - args.fairseq_path, args.pytorch_dump_folder_path, + parser.add_argument( + "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem." ) + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum" + ) + args = parser.parse_args() + convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index d2f92b00549..cf88f9d5942 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -34,6 +34,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = { "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin", "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin", "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin", + "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin", } BART_START_DOCSTRING = r""" diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index f5c0d8f1dd0..76f184f50d9 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -19,7 +19,7 @@ from .tokenization_roberta import RobertaTokenizer # vocab and merges same as roberta vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json" merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt" -_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"] +_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"] class BartTokenizer(RobertaTokenizer): diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index c463e4df3b8..db7ce633170 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -450,6 +450,38 @@ class BartModelIntegrationTests(unittest.TestCase): model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) + @slow + def test_xsum_summarization_same_as_fairseq(self): + model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device) + tok = BartTokenizer.from_pretrained("bart-large") + + PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""" + EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state." + dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",) + + hypotheses_batch = model.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=2, + max_length=62, + min_length=11, + length_penalty=1.0, + no_repeat_ngram_size=3, + early_stopping=True, + decoder_start_token_id=model.config.eos_token_ids[0], + ) + + decoded = [ + tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch + ] + self.assertEqual(EXPECTED_SUMMARY, decoded[0]) + + def test_xsum_config_generation_params(self): + config = BartConfig.from_pretrained("bart-large-xsum") + expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0) + config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()} + self.assertDictEqual(expected_params, config_params) + @slow def test_cnn_summarization_same_as_fairseq(self): hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)