transformers/docs/source/model_doc/bart.rst
Daniel Stancl 4a51b1dd9b
FlaxBart (#11537)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
2021-06-14 15:16:08 +05:30


..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

BART
-----------------------------------------------------------------------------------------------------------------------

**DISCLAIMER:** If you see something strange, file a `GitHub Issue
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@patrickvonplaten.


Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Bart model was proposed in `BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation,
Translation, and Comprehension <https://arxiv.org/abs/1910.13461>`__ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan
Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 October 2019.

According to the abstract:

- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a
  left-to-right decoder (like GPT).
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme,
  where spans of text are replaced with a single mask token.
- BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It
  matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new
  state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains
  of up to 6 ROUGE.

This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__. The authors' code can be found `here
<https://github.com/pytorch/fairseq/tree/master/examples/bart>`__.

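
The two noising operations mentioned in the abstract summary above can be illustrated with a small, library-free
sketch. The helper names ``infill_span`` and ``permute_sentences`` are hypothetical and are not part of
``transformers`` or fairseq, and the span position is chosen by hand here rather than sampled, so this only shows the
shape of the corruption, not the actual pretraining pipeline.

.. code-block:: python

    import random

    MASK = "<mask>"  # assumed mask token string, matching the mask-filling example in this page


    def infill_span(tokens, start, length, mask_token=MASK):
        """Replace tokens[start:start + length] with a single mask token (hypothetical helper)."""
        return tokens[:start] + [mask_token] + tokens[start + length:]


    def permute_sentences(sentences, rng):
        """Return a shuffled copy of the sentences (hypothetical helper)."""
        shuffled = list(sentences)
        rng.shuffle(shuffled)
        return shuffled


    tokens = "UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria".split()

    # Five tokens ("Plan to Stop Chemical Weapons") collapse into one mask token.
    corrupted = infill_span(tokens, start=6, length=5)
    print(" ".join(corrupted))  # UN Chief Says There Is No <mask> in Syria

    # Sentence order is shuffled; the sentences themselves are left untouched.
    print(permute_sentences(["First sentence.", "Second sentence.", "Third sentence."], random.Random(0)))

Because a whole span maps to a single mask token, the model has to predict how many tokens are missing as well as
what they are.
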

Examples
_______________________________________________________________________________________________________________________

- Examples and scripts for fine-tuning BART and other models for sequence to sequence tasks can be found in
  :prefix_link:`examples/pytorch/summarization/ <examples/pytorch/summarization/README.md>`.
- An example of how to train :class:`~transformers.BartForConditionalGeneration` with a Hugging Face :obj:`datasets`
  object can be found in this `forum discussion
  <https://discuss.huggingface.co/t/train-bart-for-conditional-generation-e-g-summarization/1904>`__.
- `Distilled checkpoints <https://huggingface.co/models?search=distilbart>`__ are described in this `paper
  <https://arxiv.org/abs/2010.13002>`__.


Implementation Notes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use :class:`~transformers.BartTokenizer` or
  :meth:`~transformers.BartTokenizer.encode` to get the proper splitting.
- The forward pass of :class:`~transformers.BartModel` will create the ``decoder_input_ids`` if they are not passed.
  This is different from some other modeling APIs. A typical use case of this feature is mask filling.
- Model predictions are intended to be identical to the original implementation when
  :obj:`force_bos_token_to_be_generated=True`. This only works, however, if the string you pass to
  :func:`fairseq.encode` starts with a space.
- :meth:`~transformers.generation_utils.GenerationMixin.generate` should be used for conditional generation tasks like
  summarization; see the example in its docstring.
- Models that load the :obj:`facebook/bart-large-cnn` weights will not have a :obj:`mask_token_id` and cannot perform
  mask-filling tasks.

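
To make the note about automatically created ``decoder_input_ids`` more concrete, the following library-free sketch
shows the usual "shift right" construction: the decoder input is the encoder input moved one position to the right,
with a start token placed in front. This is only an illustration written with plain Python lists. The token ids used
(``pad_token_id=1``, ``decoder_start_token_id=2``) are assumed example values, and the function is a simplified
stand-in, not the library's own code.

.. code-block:: python

    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
        """Shift each sequence one position to the right and prepend the start token (illustrative sketch)."""
        shifted = []
        for row in input_ids:
            # Drop the last token and put the decoder start token in front.
            new_row = [decoder_start_token_id] + list(row[:-1])
            # -100 is commonly used to mark ignored label positions; map it back to padding.
            new_row = [pad_token_id if token == -100 else token for token in new_row]
            shifted.append(new_row)
        return shifted


    # Two toy sequences; the second one is right-padded with the assumed pad id 1.
    input_ids = [[0, 8, 9, 10, 2], [0, 11, 2, 1, 1]]
    decoder_input_ids = shift_tokens_right(input_ids, pad_token_id=1, decoder_start_token_id=2)
    print(decoder_input_ids)  # [[2, 0, 8, 9, 10], [2, 0, 11, 2, 1]]

With inputs built this way, the decoder at position ``i`` sees the tokens before position ``i`` and is trained to
predict the token at position ``i``, which is why passing only ``input_ids`` is enough for mask filling.
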

Mask Filling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :obj:`facebook/bart-base` and :obj:`facebook/bart-large` checkpoints can be used to fill multi-token masks.

.. code-block::

    from transformers import BartForConditionalGeneration, BartTokenizer

    # Load the pretrained checkpoint and its tokenizer
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", force_bos_token_to_be_generated=True)
    tok = BartTokenizer.from_pretrained("facebook/bart-large")

    # A single <mask> may be filled with several tokens
    example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
    batch = tok(example_english_phrase, return_tensors='pt')

    # Generate the completed sentence and decode it back to text
    generated_ids = model.generate(batch['input_ids'])
    assert tok.batch_decode(generated_ids, skip_special_tokens=True) == ['UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria']


BartConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartConfig
    :members:


BartTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartTokenizer
    :members:


BartTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartTokenizerFast
    :members:


BartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartModel
    :members: forward


BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForConditionalGeneration
    :members: forward


BartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForSequenceClassification
    :members: forward


BartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForQuestionAnswering
    :members: forward


BartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForCausalLM
    :members: forward


TFBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBartModel
    :members: call


TFBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBartForConditionalGeneration
    :members: call


FlaxBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartModel
    :members: __call__, encode, decode


FlaxBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForConditionalGeneration
    :members: __call__, encode, decode


FlaxBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForSequenceClassification
    :members: __call__, encode, decode


FlaxBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBartForQuestionAnswering
    :members: __call__, encode, decode