Add bart-base (#5014)
parent 7b685f5229
commit a9f1fc6c94
@@ -278,6 +278,8 @@ For a list that includes community-uploaded models, refer to https://huggingface.co/models
 | Bart              | ``facebook/bart-large``       | | 24-layer, 1024-hidden, 16-heads, 406M parameters                                |
 |                   |                               | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
 |                   +-------------------------------+------------------------------------------------------------------------------------+
+|                   | ``facebook/bart-base``        | | 12-layer, 768-hidden, 16-heads, 139M parameters                                 |
+|                   +-------------------------------+------------------------------------------------------------------------------------+
 |                   | ``facebook/bart-large-mnli``  | | Adds a 2 layer classification head with 1 million parameters                    |
 |                   |                               | | bart-large base architecture with a classification head, finetuned on MNLI      |
 |                   +-------------------------------+------------------------------------------------------------------------------------+
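The new table row advertises bart-base as a 12-layer (6 encoder + 6 decoder), 768-hidden checkpoint. A minimal sketch, not part of this commit, that reads those figures back from the checkpoint's configuration (assumes the facebook/bart-base weights are downloadable):

    from transformers import BartConfig

    # Load the configuration shipped with the new checkpoint.
    config = BartConfig.from_pretrained("facebook/bart-base")

    # The table counts encoder and decoder layers together.
    print(config.encoder_layers + config.decoder_layers)  # layer count
    print(config.d_model)                                 # hidden size
    print(config.encoder_attention_heads)                 # attention heads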
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
 _all_bart_models = [
+    "facebook/bart-base",
     "facebook/bart-large",
     "facebook/bart-large-mnli",
     "facebook/bart-large-cnn",
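Adding the checkpoint name to _all_bart_models registers it with the tokenizer, so the shared RoBERTa vocab/merges URLs above are used for "facebook/bart-base" as well. A hedged usage sketch (not part of the commit; assumes network access to the checkpoint):

    from transformers import BartTokenizer

    # The new list entry makes this lookup-by-name succeed.
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    # Same input string as the new integration tests below.
    ids = tokenizer.encode(" I went to the <mask>.")
    print(ids)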
@@ -40,6 +40,7 @@ if is_torch_available():
         BartTokenizer,
         MBartTokenizer,
         BatchEncoding,
+        pipeline,
     )
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -565,6 +566,22 @@ class BartModelIntegrationTests(unittest.TestCase):
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
 
+    @slow
+    def test_bart_base_mask_filling(self):
+        pbase = pipeline(task="fill-mask", model="facebook/bart-base")
+        src_text = [" I went to the <mask>."]
+        results = [x["token_str"] for x in pbase(src_text)]
+        expected_results = ["Ġbathroom", "Ġrestroom", "Ġhospital", "Ġkitchen", "Ġcar"]
+        self.assertListEqual(results, expected_results)
+
+    @slow
+    def test_bart_large_mask_filling(self):
+        pbase = pipeline(task="fill-mask", model="facebook/bart-large")
+        src_text = [" I went to the <mask>."]
+        results = [x["token_str"] for x in pbase(src_text)]
+        expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"]
+        self.assertListEqual(results, expected_results)
+
     @slow
     def test_mnli_inference(self):
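The new tests exercise the fill-mask pipeline against both checkpoints and pin the expected top-5 token strings. A minimal standalone sketch of the same call, outside unittest (the exact top tokens and scores can drift between library and weight versions):

    from transformers import pipeline

    # Same task/model pair as test_bart_base_mask_filling.
    fill_mask = pipeline(task="fill-mask", model="facebook/bart-base")

    for prediction in fill_mask(" I went to the <mask>."):
        print(prediction["token_str"], prediction["score"])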