Add LongT5 model (#16792)

* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleansing

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix pt-flax equivalence, though not convinced of correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading #16148

* Add ckpt conversion script according to #16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* * Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a619828f6ddc3e65bd9bb1725a12b77fa883a46.

* Stash caching logic in Flax model

* Always use the side relative bias

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoints

* finish

* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
Daniel Stancl, 2022-06-13 22:36:58 +02:00, committed by GitHub
parent 1690094bdb
commit a72f1c9f5b
30 changed files with 7389 additions and 2 deletions

@@ -284,6 +284,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

@@ -265,6 +265,7 @@ Installing these with conda from the Flax, PyTorch and TensorFlow installation pages
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

@@ -289,6 +289,7 @@ conda install -c huggingface transformers
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

@@ -301,6 +301,7 @@ conda install -c huggingface transformers
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

@@ -256,6 +256,8 @@
title: LeViT
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5
title: LongT5
- local: model_doc/luke
title: LUKE
- local: model_doc/lxmert

@@ -107,6 +107,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
@@ -233,6 +234,7 @@ Flax), PyTorch, and/or TensorFlow.
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
| LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
| LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
| LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| M-CTC-T | ❌ | ❌ | ✅ | ❌ | ❌ |

@@ -0,0 +1,121 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LongT5
## Overview
The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916)
by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung and Yinfei Yang. It's an
encoder-decoder transformer pre-trained in a text-to-text denoising generative setting. The LongT5 model is an extension of
the T5 model, and it enables using one of two different efficient attention mechanisms: (1) Local attention, or (2)
Transient-Global attention.
The abstract from the paper is the following:
*Recent work has shown that either (1) increasing the input length or (2) increasing model size can improve the
performance of Transformer-based neural models. In this paper, we present a new model, called LongT5, with which we
explore the effects of scaling both the input length and model size at the same time. Specifically, we integrated
attention ideas from long-input transformers (ETC), and adopted pre-training strategies from summarization pre-training
(PEGASUS) into the scalable T5 architecture. The result is a new attention mechanism we call *Transient Global*
(TGlobal), which mimics ETC's local/global attention mechanism, but without requiring additional side-inputs. We are
able to achieve state-of-the-art results on several summarization tasks and outperform the original T5 models on
question answering tasks.*
Tips:
- [`LongT5ForConditionalGeneration`] is an extension of [`T5ForConditionalGeneration`], replacing the traditional
encoder *self-attention* layer with either efficient *local* attention or *transient-global* (*tglobal*) attention.
- Unlike the T5 model, LongT5 does not use a task prefix. Furthermore, it uses a different pre-training objective
inspired by the pre-training of [`PegasusForConditionalGeneration`].
- The LongT5 model is designed to work efficiently on long-range *sequence-to-sequence* tasks where the
input sequence exceeds the commonly used 512 tokens. It can handle input sequences up to 16,384 tokens long.
- For *Local Attention*, the sparse sliding-window local attention operation allows a given token to attend only `r`
tokens to the left and right of it (with `r=127` by default). *Local Attention* does not introduce any new parameters
to the model. The complexity of the mechanism is linear in input sequence length `l`: `O(l*r)`.
- *Transient Global Attention* is an extension of *Local Attention*. In addition, it allows each input token to
interact with all other tokens in the layer. This is achieved by splitting the input sequence into blocks of a fixed
length `k` (with a default `k=16`). A global token for each block is then obtained by summing and normalizing the embeddings of every token
in the block. Thanks to this, each token can attend both to nearby tokens, as in Local attention, and
to every global token, as in standard global attention (*transient* reflects the fact that the global tokens
are constructed dynamically within each attention operation). As a consequence, *TGlobal* attention introduces
a few new parameters: global relative position biases and a layer normalization for the global tokens' embeddings.
The complexity of this mechanism is `O(l(r + l/k))`. A configuration sketch contrasting the two attention types follows the evaluation example below.
- An example showing how to evaluate a fine-tuned LongT5 model on the [pubmed dataset](https://huggingface.co/datasets/scientific_papers) is below.
```python
>>> import evaluate
>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
>>> dataset = load_dataset("scientific_papers", "pubmed", split="validation")
>>> model = (
... LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
... .to("cuda")
... .half()
... )
>>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
>>> def generate_answers(batch):
... inputs_dict = tokenizer(
... batch["article"], max_length=16384, padding="max_length", truncation=True, return_tensors="pt"
... )
... input_ids = inputs_dict.input_ids.to("cuda")
... attention_mask = inputs_dict.attention_mask.to("cuda")
... output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=512, num_beams=2)
... batch["predicted_abstract"] = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
... return batch
>>> result = dataset.map(generate_answers, batched=True, batch_size=2)
>>> rouge = evaluate.load("rouge")
>>> rouge.compute(predictions=result["predicted_abstract"], references=result["abstract"])
```
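As a complement to the tips above, here is a minimal configuration sketch contrasting the two encoder attention types. All fields shown (`encoder_attention_type`, `local_radius`, `global_block_size`) are [`LongT5Config`] parameters, and the values used are the defaults.
```python
>>> from transformers import LongT5Config, LongT5ForConditionalGeneration

>>> # Local attention: sliding-window attention, adds no parameters on top of T5
>>> local_config = LongT5Config(encoder_attention_type="local", local_radius=127)

>>> # Transient-Global attention: adds global relative position biases and a layer norm
>>> # for the dynamically constructed global tokens
>>> tglobal_config = LongT5Config(encoder_attention_type="transient-global", global_block_size=16)
>>> model = LongT5ForConditionalGeneration(tglobal_config)
```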
This model was contributed by [stancld](https://huggingface.co/stancld).
The original code can be found [here](https://github.com/google-research/longt5).
## LongT5Config
[[autodoc]] LongT5Config
## LongT5Model
[[autodoc]] LongT5Model
- forward
## LongT5ForConditionalGeneration
[[autodoc]] LongT5ForConditionalGeneration
- forward
## LongT5EncoderModel
[[autodoc]] LongT5EncoderModel
- forward
## FlaxLongT5Model
[[autodoc]] FlaxLongT5Model
- __call__
- encode
- decode
## FlaxLongT5ForConditionalGeneration
[[autodoc]] FlaxLongT5ForConditionalGeneration
- __call__
- encode
- decode

@@ -66,6 +66,7 @@ Ready-made configurations include the following architectures:
- GPT-J
- I-BERT
- LayoutLM
- LongT5
- M2M100
- Marian
- mBART

@@ -239,6 +239,7 @@ _import_structure = {
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
"models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
"models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
@@ -1277,6 +1278,15 @@ else:
"LongformerSelfAttention",
]
)
_import_structure["models.longt5"].extend(
[
"LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
"LongT5EncoderModel",
"LongT5ForConditionalGeneration",
"LongT5Model",
"LongT5PreTrainedModel",
]
)
_import_structure["models.luke"].extend(
[
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2586,6 +2596,9 @@ else:
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.longt5"].extend(
["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
)
_import_structure["models.marian"].extend(
[
"FlaxMarianModel",
@@ -2850,6 +2863,7 @@ if TYPE_CHECKING:
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
@@ -3726,6 +3740,13 @@ if TYPE_CHECKING:
LongformerPreTrainedModel,
LongformerSelfAttention,
)
from .models.longt5 import (
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
LongT5EncoderModel,
LongT5ForConditionalGeneration,
LongT5Model,
LongT5PreTrainedModel,
)
from .models.luke import (
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
LukeForEntityClassification,
@@ -4784,6 +4805,7 @@ if TYPE_CHECKING:
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
from .models.mbart import (
FlaxMBartForConditionalGeneration,

@@ -76,6 +76,7 @@ from . import (
led,
levit,
longformer,
longt5,
luke,
lxmert,
m2m_100,

@@ -78,6 +78,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("led", "LEDConfig"),
("levit", "LevitConfig"),
("longformer", "LongformerConfig"),
("longt5", "LongT5Config"),
("luke", "LukeConfig"),
("lxmert", "LxmertConfig"),
("m2m_100", "M2M100Config"),
@@ -193,6 +194,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -310,6 +312,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("led", "LED"),
("levit", "LeViT"),
("longformer", "Longformer"),
("longt5", "LongT5"),
("luke", "LUKE"),
("lxmert", "LXMERT"),
("m2m_100", "M2M100"),

@@ -77,6 +77,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("led", "LEDModel"),
("levit", "LevitModel"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
("lxmert", "LxmertModel"),
("m2m_100", "M2M100Model"),
@@ -217,6 +218,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("layoutlm", "LayoutLMForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("longt5", "LongT5ForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"),
@@ -423,6 +425,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("encoder-decoder", "EncoderDecoderModel"),
("fsmt", "FSMTForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("longt5", "LongT5ForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),

@@ -41,6 +41,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mt5", "FlaxMT5Model"),
@@ -65,6 +66,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("electra", "FlaxElectraForPreTraining"),
("longt5", "FlaxLongT5ForConditionalGeneration"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
@@ -98,6 +100,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("longt5", "FlaxLongT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),

@@ -137,6 +137,13 @@ else:
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
(
"longt5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),

@@ -0,0 +1,88 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config", "LongT5OnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_longt5"] = [
"LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
"LongT5EncoderModel",
"LongT5ForConditionalGeneration",
"LongT5Model",
"LongT5PreTrainedModel",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_longt5"] = [
"FlaxLongT5ForConditionalGeneration",
"FlaxLongT5Model",
"FlaxLongT5PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config, LongT5OnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_longt5 import (
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
LongT5EncoderModel,
LongT5ForConditionalGeneration,
LongT5Model,
LongT5PreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_longt5 import (
FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model,
FlaxLongT5PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2022, The LongT5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LongT5 model configuration"""
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging
logger = logging.get_logger(__name__)
LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/LongT5-Local-Base": "https://huggingface.co/google/LongT5-Local-Base/blob/main/config.json",
"google/LongT5-Local-Large": "https://huggingface.co/google/LongT5-Local-Large/blob/main/config.json",
"google/LongT5-TGlobal-Base": "https://huggingface.co/google/LongT5-TGlobal-Base/blob/main/config.json",
"google/LongT5-TGlobal-Large": "https://huggingface.co/google/LongT5-TGlobal-Large/blob/main/config.json",
}
class LongT5Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
[google/LongT5-Local-Base](https://huggingface.co/google/LongT5-Local-Base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Arguments:
vocab_size (`int`, *optional*, defaults to 32128):
Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LongT5Model`].
d_model (`int`, *optional*, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (`int`, *optional*, defaults to 64):
Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
num_heads`.
d_ff (`int`, *optional*, defaults to 2048):
Size of the intermediate feed forward layer in each `LongT5Block`.
num_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
num_decoder_layers (`int`, *optional*):
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
num_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
local_radius (`int`, *optional*, defaults to 127):
Number of tokens to the left/right for each token to locally self-attend in a local attention mechanism.
global_block_size (`int`, *optional*, defaults to 16):
Length of blocks an input sequence is divided into for a global token representation. Used only for
`encoder_attention_type = "transient-global"`.
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
The number of buckets to use for each attention layer.
relative_attention_max_distance (`int`, *optional*, defaults to 128):
The maximum distance of the longer sequences for the bucket separation.
dropout_rate (`float`, *optional*, defaults to 0.1):
The ratio for all dropout layers.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. The original LongT5
implementation uses `"gated-gelu"`.
encoder_attention_type (`string`, *optional*, defaults to `"local"`):
Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are
supported by LongT5 implementation.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "longt5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
def __init__(
self,
vocab_size=32128,
d_model=512,
d_kv=64,
d_ff=2048,
num_layers=6,
num_decoder_layers=None,
num_heads=8,
local_radius=127,
global_block_size=16,
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
encoder_attention_type="local",
use_cache=True,
pad_token_id=0,
eos_token_id=1,
**kwargs
):
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
# default = symmetry
self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
self.num_heads = num_heads
self.local_radius = local_radius
self.global_block_size = global_block_size
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.encoder_attention_type = encoder_attention_type
self.use_cache = use_cache
act_info = self.feed_forward_proj.split("-")
self.dense_act_fn = act_info[-1]
self.is_gated_act = act_info[0] == "gated"
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
raise ValueError(
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
"'gated-gelu' or 'relu'"
)
# for backwards compatibility
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {
"input_ids": {0: "batch", 1: "encoder_sequence"},
"attention_mask": {0: "batch", 1: "encoder_sequence"},
}
if self.use_past:
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
common_inputs["decoder_input_ids"] = {0: "batch"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
return common_inputs
@property
def default_onnx_opset(self) -> int:
return 13
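To make the `feed_forward_proj` parsing in the constructor above concrete, a short sketch of the derived attributes (behavior read directly from the code shown):
```python
from transformers import LongT5Config

# "gated-gelu" splits into ("gated", "gelu"); the backwards-compatibility branch
# then maps the dense activation to "gelu_new"
cfg = LongT5Config(feed_forward_proj="gated-gelu")
assert cfg.is_gated_act and cfg.dense_act_fn == "gelu_new"

# plain "relu" keeps a non-gated feed forward with a "relu" activation
cfg = LongT5Config(feed_forward_proj="relu")
assert not cfg.is_gated_act and cfg.dense_act_fn == "relu"
```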

@@ -0,0 +1,214 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of
'src/transformers/models/t5/convert_t5x_checkpoint_to_flax.
"""
import argparse
from t5x import checkpoints
from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
config = AutoConfig.from_pretrained(config_name)
flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
if config.model_type == "t5":
encoder_attn_name = "SelfAttention"
if config.model_type == "longt5" and config.encoder_attention_type == "local":
encoder_attn_name = "LocalSelfAttention"
elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
encoder_attn_name = "TransientGlobalSelfAttention"
else:
raise ValueError(
"Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`"
" attribute with a value from ['local', 'transient-global]."
)
# Encoder
for layer_index in range(config.num_layers):
layer_name = f"layers_{str(layer_index)}"
# Self-Attention
t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
# Global input layer norm
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"]
# Layer Normalization
t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
if split_mlp_wi:
t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
else:
t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
# Layer Normalization
t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
# Assigning
flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"]
flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key
flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out
flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query
flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value
flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
# Global input layer norm
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"][
"weight"
] = t5x_global_layer_norm
if split_mlp_wi:
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
else:
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block
# Only for layer 0:
t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
"embedding"
] = t5x_encoder_rel_embedding
# Side/global relative position_bias + layer norm
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
"embedding"
] = t5x_encoder_global_rel_embedding
# Assigning
t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
# Decoder
for layer_index in range(config.num_layers):
layer_name = f"layers_{str(layer_index)}"
# Self-Attention
t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
# Layer Normalization
t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][
"scale"
]
# Encoder-Decoder-Attention
t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]
t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"]
t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"]
t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"]
t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"]
# Layer Normalization
t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
# MLP
if split_mlp_wi:
t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
else:
t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
# Layer Normalization
t5x_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
# Assigning
flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"]
flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
if split_mlp_wi:
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
else:
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block
# Decoder Normalization
t5x_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
flax_model.params["decoder"]["final_layer_norm"]["weight"] = t5x_decoder_norm
# Only for layer 0:
t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
"embedding"
] = t5x_decoder_rel_embedding
# Token Embeddings
t5x_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
flax_model.params["shared"]["embedding"] = t5x_token_embeddings
# LM Head (only in v1.1 and LongT5 checkpoints)
if "logits_dense" in t5x_model["target"]["decoder"]:
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
flax_model.save_pretrained(flax_dump_folder_path)
print("T5X Model was sucessfully converted!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint."
)
parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.")
parser.add_argument(
"--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
)
args = parser.parse_args()
convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
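For reference, the converter can be driven from Python as well as from the CLI. A minimal sketch, assuming the script is saved as `convert_longt5x_checkpoint_to_flax.py` and that the `t5x` package is installed (the paths below are illustrative placeholders):
```python
from convert_longt5x_checkpoint_to_flax import convert_t5x_checkpoint_to_flax

convert_t5x_checkpoint_to_flax(
    t5x_checkpoint_path="/path/to/t5x_checkpoint",  # placeholder path
    config_name="google/LongT5-Local-Base",
    flax_dump_folder_path="/tmp/longt5_flax",  # placeholder output dir
)
```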

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -281,6 +281,13 @@ class FeaturesManager:
"token-classification",
onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
),
"longt5": supported_features_mapping(
"default",
"default-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
onnx_config_cls="models.longt5.LongT5OnnxConfig",
),
"marian": supported_features_mapping(
"default",
"default-with-past",

@@ -725,6 +725,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5Model(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5PreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxMarianModel(metaclass=DummyObject):
_backends = ["flax"]

@@ -2605,6 +2605,37 @@ class LongformerSelfAttention(metaclass=DummyObject):
requires_backends(self, ["torch"])
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None
class LongT5EncoderModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None

test_results.txt (new file, 9 lines)

@@ -0,0 +1,9 @@
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred

@@ -0,0 +1,757 @@
# coding=utf-8
# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
import transformers
from transformers import is_flax_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flax,
require_sentencepiece,
require_tokenizers,
slow,
)
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.longt5.modeling_flax_longt5 import (
FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model,
shift_tokens_right,
)
class FlaxLongT5ModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
local_radius=5,
encoder_attention_type="local",
global_block_size=3,
# For common tests
is_training=True,
use_attention_mask=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
decoder_layers=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
self.local_radius = local_radius
self.block_len = local_radius + 1
self.encoder_attention_type = encoder_attention_type
self.global_block_size = global_block_size
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
self.decoder_layers = decoder_layers
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
decoder_attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
config = LongT5Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
local_radius=self.local_radius,
encoder_attention_type=self.encoder_attention_type,
global_block_size=self.global_block_size,
)
return (
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
)
def create_and_check_model(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
):
model = FlaxLongT5Model(config=config)
result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output = result.last_hidden_state
encoder_output = result.encoder_last_hidden_state
self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size))
def check_use_cache_forward_with_attn_mask(
self,
model_class_name,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(input_ids)
# prevent a fully zeroed-out attention mask
decoder_attention_mask = jnp.ones_like(decoder_attention_mask)
decoder_attention_mask_cache = jnp.concatenate(
[
decoder_attention_mask,
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask_cache,
past_key_values=past_key_values,
)
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
past_key_values=outputs_cache.past_key_values,
decoder_attention_mask=decoder_attention_mask_cache,
)
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return config, inputs_dict
@require_flax
class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
all_generative_model_classes = (FlaxLongT5ForConditionalGeneration,) if is_flax_available() else ()
is_encoder_decoder = True
def setUp(self):
self.model_tester = FlaxLongT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated-gelu feed-forward and untied word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_use_cache_forward_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)
def test_encode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def encode_jitted(input_ids, attention_mask=None, **kwargs):
return model.encode(input_ids=input_ids, attention_mask=attention_mask)
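                # run the same call compiled and eagerly; both variants should
                # return identically structured and shaped outputs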
with self.subTest("JIT Enabled"):
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_decode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
prepared_inputs_dict = {
"decoder_input_ids": inputs_dict["decoder_input_ids"],
"decoder_attention_mask": inputs_dict["decoder_attention_mask"],
"encoder_outputs": encoder_outputs,
}
@jax.jit
def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
return model.decode(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
with self.subTest("JIT Enabled"):
jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_shift_right(self):
decoder_start_token_id = 0
pad_token_id = 1
labels = np.arange(2, 102).reshape(5, 20)
labels[:2, 15:] = -100
decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
np_decoder_input_ids = np.array(decoder_input_ids)
padded_slice = np_decoder_input_ids[:2, (15 + 1) :]
self.assertTrue((padded_slice == 1).all())
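        # rows without -100 labels: decoder inputs are the labels shifted right by one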
not_padded_slice = np_decoder_input_ids[2:, 1:]
rolled_labels = np.roll(labels[2:], 1)[:, 1:]
self.assertTrue((not_padded_slice == rolled_labels).all())
self.assertTrue((np_decoder_input_ids[:, 0] == 0).all())
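        # e.g., with the values above: a labels row [2, 3, 4, ...] becomes decoder
        # inputs [0, 2, 3, ...] (decoder_start_token_id prepended), and -100 label
        # positions are replaced by the pad token id (1)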
# overwrite since special base model prefix is used
def test_save_load_from_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
def test_save_load_to_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_length = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
block_len = getattr(self.model_tester, "block_len", None)
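        # LongT5 local attention lets each block attend to itself and its left and
        # right neighbors, hence the expected key length of 3 * block_len below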
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 5
# Question Answering model returns start_logits and end_logits
if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
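    # Reuses the full FlaxLongT5ModelTest suite with transient-global encoder
    # attention; only the expected encoder attention shapes change below.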
def setUp(self):
self.model_tester = FlaxLongT5ModelTester(self, encoder_attention_type="transient-global")
self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_length = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
block_len = getattr(self.model_tester, "block_len", None)
global_block_size = getattr(self.model_tester, "global_block_size", None)
global_seq_len = encoder_seq_length // global_block_size
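        # transient-global attention additionally attends to one global token per
        # global block, extending the key length to 3 * block_len + global_seq_len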
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 5
# Question Answering model returns start_logits and end_logits
if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
)
@require_sentencepiece
@require_tokenizers
@require_flax
class FlaxLongT5ModelIntegrationTests(unittest.TestCase):
model_path = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
def expected_summary(self):
return [
"background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
" developing world . it provides an excellent resolution for visualization of the coronary arteries for"
" catheter - based or operating interventions . although the association of this technique with major"
" complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
" and noncardiac complications . computed tomography coronary angiography is a promising technique for the"
" evaluation of cad noninvasively . it assesses disease within the coronary artery and provides"
" qualitative and quantitative information about nonobstructive atherosclerotic plaque"
]
@slow
def test_summarization(self):
model = FlaxLongT5ForConditionalGeneration.from_pretrained(self.model_path)
tok = AutoTokenizer.from_pretrained(self.model_path)
        ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 128-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with significant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically proven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoronary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients scheduled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed consent for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n patients with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients with heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technology , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row are acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast material , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a region of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement , double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the study group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermediate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnostic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc . in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with invasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a pretest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary angiography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in patients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patients is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """
dct = tok(
[ARTICLE],
max_length=1024,
padding="max_length",
truncation=True,
return_tensors="np",
)
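        # inputs are truncated to 1024 tokens to keep the test manageable; the
        # checkpoint name suggests it was fine-tuned on 16384-token inputs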
hypotheses_batch = model.generate(
**dct,
num_beams=4,
length_penalty=2.0,
max_length=142,
min_length=56,
do_sample=False,
early_stopping=True,
).sequences
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
self.expected_summary(),
decoded,
)

File diff suppressed because it is too large

@@ -213,6 +213,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("blenderbot-small", "facebook/blenderbot_small-90M"),
     ("blenderbot", "facebook/blenderbot-400M-distill"),
     ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
+    ("longt5", "google/LongT5-Local-Base"),
+    ("longt5", "google/LongT5-TGlobal-Base"),
 }
 # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.


@@ -18,6 +18,7 @@ from transformers import (
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     LEDConfig,
+    LongT5Config,
     SummarizationPipeline,
     T5Config,
     pipeline,
@@ -54,8 +55,8 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])
-        if not isinstance(model.config, (T5Config, LEDConfig)):
-            # LED, T5 can handle it.
+        if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
+            # LED, T5, LongT5 can handle it.
             # Too long.
             with self.assertRaises(Exception):
                 outputs = summarizer("This " * 1000)


@@ -36,6 +36,7 @@ PATH_TO_DOC = "docs/source/en"
 # Update this list with models that are supposed to be private.
 PRIVATE_MODELS = [
     "DPRSpanPredictor",
+    "LongT5Stack",
     "RealmBertModel",
     "T5Stack",
     "TFDPRSpanPredictor",


@@ -36,6 +36,7 @@ src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
 src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
 src/transformers/models/longformer/modeling_longformer.py
 src/transformers/models/longformer/modeling_tf_longformer.py
+src/transformers/models/longt5/modeling_longt5.py
 src/transformers/models/marian/modeling_marian.py
 src/transformers/models/mbart/modeling_mbart.py
 src/transformers/models/mobilebert/modeling_mobilebert.py