..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

BART
-----------------------------------------------------------------------------------------------------------------------

**DISCLAIMER:** If you see something strange, file a `Github Issue
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@patrickvonplaten

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Bart model was proposed in `BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation,
Translation, and Comprehension <https://arxiv.org/abs/1910.13461>`__ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan
Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.

According to the abstract,

- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a
  left-to-right decoder (like GPT).
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme,
  where spans of text are replaced with a single mask token.
- BART is particularly effective when fine-tuned for text generation but also works well for comprehension tasks. It
  matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new
  state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains
  of up to 6 ROUGE.

This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__. The authors' code can be found `here
<https://github.com/pytorch/fairseq/tree/master/examples/bart>`__.

Examples
_______________________________________________________________________________________________________________________

- Examples and scripts for fine-tuning BART and other models for sequence to sequence tasks can be found in
  :prefix_link:`examples/pytorch/summarization/ <examples/pytorch/summarization/README.md>`.
- An example of how to train :class:`~transformers.BartForConditionalGeneration` with a Hugging Face :obj:`datasets`
  object can be found in this `forum discussion
  <https://discuss.huggingface.co/t/train-bart-for-conditional-generation-e-g-summarization/1904>`__; a minimal
  preprocessing sketch is also shown below.
- `Distilled checkpoints <https://huggingface.co/models?search=distilbart>`__ are described in this `paper
  <https://arxiv.org/abs/2010.13002>`__.

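The sketch below shows one way to prepare a :obj:`datasets` object for seq2seq fine-tuning. The XSum dataset, its
``document`` and ``summary`` column names, and the length limits are only assumptions for illustration; adapt them to
your own data.

.. code-block::

    from datasets import load_dataset
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    # Assumed example dataset with "document" and "summary" columns
    dataset = load_dataset("xsum", split="train[:100]")

    def preprocess(examples):
        # Tokenize the source documents
        model_inputs = tokenizer(examples["document"], max_length=1024, truncation=True)
        # Tokenize the target summaries and attach them as labels
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(examples["summary"], max_length=128, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
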
Implementation Notes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use :class:`~transformers.BartTokenizer` or
  :meth:`~transformers.BartTokenizer.encode` to get the proper splitting.
- The forward pass of :class:`~transformers.BartModel` will create the ``decoder_input_ids`` if they are not passed.
  This is different from some other modeling APIs. A typical use case of this feature is mask filling.
- Model predictions are intended to be identical to the original implementation when
  :obj:`force_bos_token_to_be_generated=True`. This only works, however, if the string you pass to
  :func:`fairseq.encode` starts with a space.
- :meth:`~transformers.generation_utils.GenerationMixin.generate` should be used for conditional generation tasks like
  summarization; see the example in that docstring and the sketch after this list.
- Models that load the ``facebook/bart-large-cnn`` weights will not have a :obj:`mask_token_id` and therefore cannot
  perform mask-filling tasks.

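A minimal summarization sketch using the ``facebook/bart-large-cnn`` checkpoint follows; the article text and the
generation parameters (beam size, maximum length) are only illustrative.

.. code-block::

    from transformers import BartForConditionalGeneration, BartTokenizer

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    article = (
        "The tower is 324 metres tall, about the same height as an 81-storey building, "
        "and the tallest structure in Paris."
    )
    inputs = tokenizer([article], max_length=1024, truncation=True, return_tensors="pt")

    # generate() builds the decoder inputs internally, so only the encoder inputs are required
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60, early_stopping=True)
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
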
Mask Filling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :obj:`facebook/bart-base` and :obj:`facebook/bart-large` checkpoints can be used to fill multi-token masks.

.. code-block::

    from transformers import BartForConditionalGeneration, BartTokenizer

    # force_bos_token_to_be_generated keeps predictions aligned with the original fairseq implementation
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", force_bos_token_to_be_generated=True)
    tok = BartTokenizer.from_pretrained("facebook/bart-large")

    example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
    batch = tok(example_english_phrase, return_tensors='pt')
    generated_ids = model.generate(batch['input_ids'])
    assert tok.batch_decode(generated_ids, skip_special_tokens=True) == ['UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria']

BartConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartConfig
    :members:


BartTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartTokenizer
    :members:


BartTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartTokenizerFast
    :members:


BartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartModel
    :members: forward


BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForConditionalGeneration
    :members: forward


BartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForSequenceClassification
    :members: forward


BartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForQuestionAnswering
    :members: forward


BartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForCausalLM
    :members: forward


TFBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBartModel
    :members: call


TFBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBartForConditionalGeneration
    :members: call

FlaxBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartModel
    :members: __call__, encode, decode


FlaxBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForConditionalGeneration
    :members: __call__, encode, decode

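The Flax classes expose separate :meth:`encode` and :meth:`decode` methods in addition to :meth:`__call__`. A minimal
sketch of running the encoder once and then calling the decoder, assuming the ``facebook/bart-base`` checkpoint; the
example sentence is illustrative only.

.. code-block::

    import jax.numpy as jnp
    from transformers import BartTokenizer, FlaxBartForConditionalGeneration

    model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    inputs = tokenizer("My friends are good but they eat too many carbs.", return_tensors="np")

    # Run the encoder once; its outputs can be reused for every decoding step
    encoder_outputs = model.encode(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Start decoding from the configured decoder start token
    decoder_input_ids = jnp.ones((1, 1), dtype="i4") * model.config.decoder_start_token_id
    outputs = model.decode(decoder_input_ids, encoder_outputs)
    logits = outputs.logits
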
FlaxBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForSequenceClassification
    :members: __call__, encode, decode


FlaxBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForQuestionAnswering
    :members: __call__, encode, decode