From a9f1fc6c94e6d3c489d54dadbab60c612e1d7fe2 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Mon, 15 Jun 2020 13:29:26 -0400
Subject: [PATCH] Add bart-base (#5014)

---
 docs/source/pretrained_models.rst     |  2 ++
 src/transformers/tokenization_bart.py |  1 +
 tests/test_modeling_bart.py           | 17 +++++++++++++++++
 3 files changed, 20 insertions(+)

diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst
index d6ecda5f145..44e4dded6db 100644
--- a/docs/source/pretrained_models.rst
+++ b/docs/source/pretrained_models.rst
@@ -278,6 +278,8 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
 | Bart | ``facebook/bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
 | | | | (see `details `_) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
+| | ``facebook/bart-base`` | | 12-layer, 768-hidden, 16-heads, 139M parameters |
+| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | | ``facebook/bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
 | | | | bart-large base architecture with a classification head, finetuned on MNLI |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index e2297aa8d68..89cfaf1cffc 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
 _all_bart_models = [
+    "facebook/bart-base",
     "facebook/bart-large",
     "facebook/bart-large-mnli",
     "facebook/bart-large-cnn",
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 60800be6d3a..c48f20dc084 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -40,6 +40,7 @@ if is_torch_available():
         BartTokenizer,
         MBartTokenizer,
         BatchEncoding,
+        pipeline,
     )
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -565,6 +566,22 @@ class BartModelIntegrationTests(unittest.TestCase):
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
 
+    @slow
+    def test_bart_base_mask_filling(self):
+        pbase = pipeline(task="fill-mask", model="facebook/bart-base")
+        src_text = [" I went to the <mask>."]
+        results = [x["token_str"] for x in pbase(src_text)]
+        expected_results = ["Ġbathroom", "Ġrestroom", "Ġhospital", "Ġkitchen", "Ġcar"]
+        self.assertListEqual(results, expected_results)
+
+    @slow
+    def test_bart_large_mask_filling(self):
+        pbase = pipeline(task="fill-mask", model="facebook/bart-large")
+        src_text = [" I went to the <mask>."]
+        results = [x["token_str"] for x in pbase(src_text)]
+        expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"]
+        self.assertListEqual(results, expected_results)
+
     @slow
     def test_mnli_inference(self):
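
To try the new checkpoint by hand, here is a minimal sketch mirroring the integration tests above (assuming `transformers` is installed with PyTorch available and the `facebook/bart-base` weights can be downloaded):

```python
from transformers import pipeline

# Fill-mask pipeline backed by the newly added checkpoint; the weights
# are downloaded on first use.
fill_mask = pipeline(task="fill-mask", model="facebook/bart-base")

# Same input as the tests above; <mask> marks the token to predict.
predictions = fill_mask(" I went to the <mask>.")

for p in predictions:
    # Each prediction carries "sequence", "score", "token", and "token_str".
    # "token_str" is the raw byte-level BPE token, so a leading space shows
    # up as "Ġ" (e.g. "Ġbathroom"), matching expected_results in the tests.
    print(p["token_str"], round(p["score"], 4))
```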