Mirror of https://github.com/huggingface/transformers.git
Add LongT5 model (#16792)
* Initial commit
* Make some fixes
* Make PT model full forward pass
* Drop TF & Flax implementation, fix copies etc
* Add Flax model and update some corresponding stuff
* Drop some TF things
* Update config and flax local attn
* Add encoder_attention_type to config
* .
* Update docs
* Do some cleansing
* Fix some issues -> make style; add some docs
* Fix position_bias + mask addition + Update tests
* Fix repo consistency
* Fix model consistency by removing flax operation over attn_mask
* [WIP] Add PT TGlobal LongT5
* .
* [WIP] Add flax tglobal model
* [WIP] Update flax model to use the right attention type in the encoder
* Fix flax tglobal model forward pass
* Make the use of global_relative_attention_bias
* Add test suites for TGlobal model
* Fix minor bugs, clean code
* Fix pt-flax equivalence though not convinced with correctness
* Fix LocalAttn implementation to match the original impl. + update READMEs
* Few updates
* Update: [Flax] improve large model init and loading #16148
* Add ckpt conversion script according to #16853 + handle torch device placement
* Minor updates to conversion script.
* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM
* gpu support + dtype fix
* Apply some suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies
* Remove caching logic for local & tglobal attention
* Apply another batch of suggestions from code review
* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx
* Fix converting script + revert config file change
* Revert "Remove caching logic for local & tglobal attention"
  This reverts commit 2a619828f6ddc3e65bd9bb1725a12b77fa883a46.
* Stash caching logic in Flax model
* Make side relative bias used always
* Drop caching logic in PT model
* Return side bias as it was
* Drop all remaining model parallel logic
* Remove clamp statements
* Move test files to the proper place
* Update docs with new version of hf-doc-builder
* Fix test imports
* Make some minor improvements
* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray
* Fix TGlobal for ONNX conversion + update docs
* fix _make_global_fixed_block_ids and masked neg value
* update flax model
* style and quality
* fix imports
* remove load_tf_weights_in_longt5 from init and fix copies
* add slow test for TGlobal model
* typo fix
* Drop obsolete is_parallelizable and one warning
* Update __init__ files to fix repo-consistency
* fix pipeline test
* Fix some device placements
* [wip]: Update tests -- need to generate summaries to update expected_summary
* Fix quality
* Update LongT5 model card
* Update (slow) summarization tests
* make style
* rename checkpoints
* finish
* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
parent 1690094bdb
commit a72f1c9f5b
@ -284,6 +284,7 @@ Current number of checkpoints: ** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
|
||||
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
|
||||
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
|
||||
|
@ -265,6 +265,7 @@ how to install Flax, PyTorch, and TensorFlow with conda (see their installation pages)
|
||||
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
|
||||
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
|
||||
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
|
||||
|
@ -289,6 +289,7 @@ conda install -c huggingface transformers
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
@ -301,6 +301,7 @@ conda install -c huggingface transformers
|
||||
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
|
||||
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
|
||||
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
|
||||
|
@ -256,6 +256,8 @@
|
||||
title: LeViT
|
||||
- local: model_doc/longformer
|
||||
title: Longformer
|
||||
- local: model_doc/longt5
|
||||
title: LongT5
|
||||
- local: model_doc/luke
|
||||
title: LUKE
|
||||
- local: model_doc/lxmert
|
||||
|
@ -107,6 +107,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
|
||||
1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LeViT](model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
|
||||
1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LongT5](model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
|
||||
1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M-CTC-T](model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
|
||||
@ -233,6 +234,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
| LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| M-CTC-T | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
docs/source/en/model_doc/longt5.mdx (new file, 121 lines)
@ -0,0 +1,121 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LongT5
|
||||
|
||||
## Overview
|
||||
|
||||
The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916)
by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung and Yinfei Yang. It is an
encoder-decoder transformer pre-trained in a text-to-text denoising generative setting. LongT5 is an extension of the
T5 model, and it enables using one of two efficient attention mechanisms: (1) Local attention, or (2)
Transient-Global attention.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recent work has shown that either (1) increasing the input length or (2) increasing model size can improve the
|
||||
performance of Transformer-based neural models. In this paper, we present a new model, called LongT5, with which we
|
||||
explore the effects of scaling both the input length and model size at the same time. Specifically, we integrated
|
||||
attention ideas from long-input transformers (ETC), and adopted pre-training strategies from summarization pre-training
|
||||
(PEGASUS) into the scalable T5 architecture. The result is a new attention mechanism we call {\em Transient Global}
|
||||
(TGlobal), which mimics ETC's local/global attention mechanism, but without requiring additional side-inputs. We are
|
||||
able to achieve state-of-the-art results on several summarization tasks and outperform the original T5 models on
|
||||
question answering tasks.*
|
||||
|
||||
Tips:

- [`LongT5ForConditionalGeneration`] is an extension of [`T5ForConditionalGeneration`], replacing the traditional
encoder *self-attention* layer with either efficient *local* attention or *transient-global* (*tglobal*) attention.
- Unlike the T5 model, LongT5 does not use a task prefix. Furthermore, it uses a different pre-training objective
inspired by the pre-training of [`PegasusForConditionalGeneration`].
- The LongT5 model is designed to work efficiently on long-range *sequence-to-sequence* tasks where the
input sequence exceeds the commonly used 512 tokens. It can handle input sequences of up to 16,384 tokens.
- For *Local Attention*, the sparse sliding-window local attention operation allows a given token to attend only to `r`
tokens to the left and right of it (with `r=127` by default). *Local Attention* does not introduce any new parameters
to the model. The complexity of the mechanism is linear in the input sequence length `l`: `O(l*r)`.
- *Transient Global Attention* is an extension of *Local Attention*. In addition, it allows each input token to
interact with all other tokens in the layer. This is achieved by splitting the input sequence into blocks of a fixed
length `k` (with a default `k=16`). A global token for each block is then obtained by summing and normalizing the embeddings of every token
in the block. Thanks to this, each token can attend both to nearby tokens, as in Local Attention, and
to every global token, as in standard global attention (*transient* refers to the fact that the global tokens
are constructed dynamically within each attention operation). As a consequence, *TGlobal* attention introduces
a few new parameters: global relative position biases and a layer normalization for the global token embeddings.
The complexity of this mechanism is `O(l(r + l/k))`.
- An example showing how to evaluate a fine-tuned LongT5 model on the [pubmed dataset](https://huggingface.co/datasets/scientific_papers) is below.

```python
>>> import evaluate
>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration

>>> dataset = load_dataset("scientific_papers", "pubmed", split="validation")
>>> model = (
...     LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
...     .to("cuda")
...     .half()
... )
>>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")


>>> def generate_answers(batch):
...     inputs_dict = tokenizer(
...         batch["article"], max_length=16384, padding="max_length", truncation=True, return_tensors="pt"
...     )
...     input_ids = inputs_dict.input_ids.to("cuda")
...     attention_mask = inputs_dict.attention_mask.to("cuda")
...     output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=512, num_beams=2)
...     batch["predicted_abstract"] = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
...     return batch


>>> result = dataset.map(generate_answers, batched=True, batch_size=2)
>>> rouge = evaluate.load("rouge")
>>> rouge.compute(predictions=result["predicted_abstract"], references=result["abstract"])
```
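
To make the complexity figures in the tips above concrete, here is a back-of-the-envelope sketch (not part of the library) that plugs the default values `l=16384`, `r=127` and `k=16` into the formulas quoted above:

```python
# Rough count of attended positions per encoder layer, using the defaults quoted in the tips above.
l, r, k = 16_384, 127, 16  # input length, local radius, global block size

local_positions = l * (2 * r + 1)  # Local Attention: each token attends to its 2r + 1 neighbours
global_tokens = l // k  # TGlobal adds one global token per block of k input tokens
tglobal_positions = l * (2 * r + 1 + global_tokens)  # neighbours plus every global token

print(local_positions, global_tokens, tglobal_positions)
# 4177920 1024 20955136 -- both far below the l * l = 268435456 of full quadratic self-attention
```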
|
||||
|
||||
This model was contributed by [stancld](https://huggingface.co/stancld).
The original code can be found [here](https://github.com/google-research/longt5).
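
For a quick start without the dataset machinery above, the following minimal sketch runs a single summarization pass with the same checkpoint; the article string is a placeholder for any long document, and the generation settings simply mirror the evaluation example.

```python
>>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration

>>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
>>> model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")

>>> article = "..."  # placeholder: any long article you want to summarize
>>> inputs = tokenizer(article, max_length=16384, truncation=True, return_tensors="pt")
>>> output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=512, num_beams=2)
>>> summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
```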
|
||||
|
||||
|
||||
## LongT5Config
|
||||
|
||||
[[autodoc]] LongT5Config
|
||||
|
||||
## LongT5Model
|
||||
|
||||
[[autodoc]] LongT5Model
|
||||
- forward
|
||||
|
||||
## LongT5ForConditionalGeneration
|
||||
|
||||
[[autodoc]] LongT5ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## LongT5EncoderModel
|
||||
|
||||
[[autodoc]] LongT5EncoderModel
|
||||
- forward
|
||||
|
||||
## FlaxLongT5Model
|
||||
|
||||
[[autodoc]] FlaxLongT5Model
|
||||
- __call__
|
||||
- encode
|
||||
- decode
|
||||
|
||||
## FlaxLongT5ForConditionalGeneration
|
||||
|
||||
[[autodoc]] FlaxLongT5ForConditionalGeneration
|
||||
- __call__
|
||||
- encode
|
||||
- decode
|
@ -66,6 +66,7 @@ Ready-made configurations include the following architectures:
|
||||
- GPT-J
|
||||
- I-BERT
|
||||
- LayoutLM
|
||||
- LongT5
|
||||
- M2M100
|
||||
- Marian
|
||||
- mBART
|
||||
|
@ -239,6 +239,7 @@ _import_structure = {
|
||||
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
|
||||
"models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
|
||||
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
|
||||
"models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
|
||||
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
|
||||
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
|
||||
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
|
||||
@ -1277,6 +1278,15 @@ else:
|
||||
"LongformerSelfAttention",
|
||||
]
|
||||
)
|
||||
_import_structure["models.longt5"].extend(
|
||||
[
|
||||
"LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"LongT5EncoderModel",
|
||||
"LongT5ForConditionalGeneration",
|
||||
"LongT5Model",
|
||||
"LongT5PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.luke"].extend(
|
||||
[
|
||||
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2586,6 +2596,9 @@ else:
|
||||
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
|
||||
_import_structure["models.longt5"].extend(
|
||||
["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.marian"].extend(
|
||||
[
|
||||
"FlaxMarianModel",
|
||||
@ -2850,6 +2863,7 @@ if TYPE_CHECKING:
|
||||
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
|
||||
from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
|
||||
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
|
||||
from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
|
||||
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
|
||||
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
|
||||
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
|
||||
@ -3726,6 +3740,13 @@ if TYPE_CHECKING:
|
||||
LongformerPreTrainedModel,
|
||||
LongformerSelfAttention,
|
||||
)
|
||||
from .models.longt5 import (
|
||||
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LongT5EncoderModel,
|
||||
LongT5ForConditionalGeneration,
|
||||
LongT5Model,
|
||||
LongT5PreTrainedModel,
|
||||
)
|
||||
from .models.luke import (
|
||||
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LukeForEntityClassification,
|
||||
@ -4784,6 +4805,7 @@ if TYPE_CHECKING:
|
||||
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
|
||||
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
|
||||
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
|
||||
from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
|
||||
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
|
||||
from .models.mbart import (
|
||||
FlaxMBartForConditionalGeneration,
|
||||
|
@ -76,6 +76,7 @@ from . import (
|
||||
led,
|
||||
levit,
|
||||
longformer,
|
||||
longt5,
|
||||
luke,
|
||||
lxmert,
|
||||
m2m_100,
|
||||
|
@ -78,6 +78,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LEDConfig"),
|
||||
("levit", "LevitConfig"),
|
||||
("longformer", "LongformerConfig"),
|
||||
("longt5", "LongT5Config"),
|
||||
("luke", "LukeConfig"),
|
||||
("lxmert", "LxmertConfig"),
|
||||
("m2m_100", "M2M100Config"),
|
||||
@ -193,6 +194,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@ -310,6 +312,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("led", "LED"),
|
||||
("levit", "LeViT"),
|
||||
("longformer", "Longformer"),
|
||||
("longt5", "LongT5"),
|
||||
("luke", "LUKE"),
|
||||
("lxmert", "LXMERT"),
|
||||
("m2m_100", "M2M100"),
|
||||
|
@ -77,6 +77,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LEDModel"),
|
||||
("levit", "LevitModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("longt5", "LongT5Model"),
|
||||
("luke", "LukeModel"),
|
||||
("lxmert", "LxmertModel"),
|
||||
("m2m_100", "M2M100Model"),
|
||||
@ -217,6 +218,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("layoutlm", "LayoutLMForMaskedLM"),
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("longformer", "LongformerForMaskedLM"),
|
||||
("longt5", "LongT5ForConditionalGeneration"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("marian", "MarianMTModel"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
@ -423,6 +425,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("encoder-decoder", "EncoderDecoderModel"),
|
||||
("fsmt", "FSMTForConditionalGeneration"),
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("longt5", "LongT5ForConditionalGeneration"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("marian", "MarianMTModel"),
|
||||
("mbart", "MBartForConditionalGeneration"),
|
||||
|
@ -41,6 +41,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("gpt2", "FlaxGPT2Model"),
|
||||
("gpt_neo", "FlaxGPTNeoModel"),
|
||||
("gptj", "FlaxGPTJModel"),
|
||||
("longt5", "FlaxLongT5Model"),
|
||||
("marian", "FlaxMarianModel"),
|
||||
("mbart", "FlaxMBartModel"),
|
||||
("mt5", "FlaxMT5Model"),
|
||||
@ -65,6 +66,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("bert", "FlaxBertForPreTraining"),
|
||||
("big_bird", "FlaxBigBirdForPreTraining"),
|
||||
("electra", "FlaxElectraForPreTraining"),
|
||||
("longt5", "FlaxLongT5ForConditionalGeneration"),
|
||||
("mbart", "FlaxMBartForConditionalGeneration"),
|
||||
("mt5", "FlaxMT5ForConditionalGeneration"),
|
||||
("roberta", "FlaxRobertaForMaskedLM"),
|
||||
@ -98,6 +100,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
|
||||
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
|
||||
("encoder-decoder", "FlaxEncoderDecoderModel"),
|
||||
("longt5", "FlaxLongT5ForConditionalGeneration"),
|
||||
("marian", "FlaxMarianMTModel"),
|
||||
("mbart", "FlaxMBartForConditionalGeneration"),
|
||||
("mt5", "FlaxMT5ForConditionalGeneration"),
|
||||
|
@ -137,6 +137,13 @@ else:
|
||||
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"longt5",
|
||||
(
|
||||
"T5Tokenizer" if is_sentencepiece_available() else None,
|
||||
"T5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("luke", ("LukeTokenizer", None)),
|
||||
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
|
src/transformers/models/longt5/__init__.py (new file, 88 lines)
@ -0,0 +1,88 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config", "LongT5OnnxConfig"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_longt5"] = [
|
||||
"LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"LongT5EncoderModel",
|
||||
"LongT5ForConditionalGeneration",
|
||||
"LongT5Model",
|
||||
"LongT5PreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_longt5"] = [
|
||||
"FlaxLongT5ForConditionalGeneration",
|
||||
"FlaxLongT5Model",
|
||||
"FlaxLongT5PreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config, LongT5OnnxConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_longt5 import (
|
||||
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LongT5EncoderModel,
|
||||
LongT5ForConditionalGeneration,
|
||||
LongT5Model,
|
||||
LongT5PreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_longt5 import (
|
||||
FlaxLongT5ForConditionalGeneration,
|
||||
FlaxLongT5Model,
|
||||
FlaxLongT5PreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
src/transformers/models/longt5/configuration_longt5.py (new file, 178 lines)
@ -0,0 +1,178 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022, The LongT5 Authors and HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" LongT5 model configuration"""
|
||||
from typing import Mapping
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxSeq2SeqConfigWithPast
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"google/LongT5-Local-Base": "https://huggingface.co/google/LongT5-Local-Base/blob/main/config.json",
|
||||
"google/LongT5-Local-Large": "https://huggingface.co/google/LongT5-Local-Large/blob/main/config.json",
|
||||
"google/LongT5-TGlobal-Base": "https://huggingface.co/google/LongT5-TGlobal-Base/blob/main/config.json",
|
||||
"google/LongT5-TGlobal-Large": "https://huggingface.co/google/LongT5-TGlobal-Large/blob/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class LongT5Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
|
||||
used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
|
||||
[google/LongT5-Local-Base](https://huggingface.co/google/LongT5-Local-Base) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Arguments:
|
||||
vocab_size (`int`, *optional*, defaults to 32128):
|
||||
Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`LongT5Model`].
|
||||
d_model (`int`, *optional*, defaults to 512):
|
||||
Size of the encoder layers and the pooler layer.
|
||||
d_kv (`int`, *optional*, defaults to 64):
|
||||
Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
|
||||
num_heads`.
|
||||
d_ff (`int`, *optional*, defaults to 2048):
|
||||
Size of the intermediate feed forward layer in each `LongT5Block`.
|
||||
num_layers (`int`, *optional*, defaults to 6):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_decoder_layers (`int`, *optional*):
|
||||
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
|
||||
num_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
local_radius (`int`, *optional*, defaults to 127):
Number of tokens to the left/right for each token to locally self-attend in a local attention mechanism.
global_block_size (`int`, *optional*, defaults to 16):
Length of blocks an input sequence is divided into for a global token representation. Used only for
`encoder_attention_type = "transient-global"`.
|
||||
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
|
||||
The number of buckets to use for each attention layer.
|
||||
relative_attention_max_distance (`int`, *optional*, defaults to 128):
|
||||
The maximum distance of the longer sequences for the bucket separation.
|
||||
dropout_rate (`float`, *optional*, defaults to 0.1):
|
||||
The ratio for all dropout layers.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
||||
testing).
|
||||
feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. The original LongT5
implementation uses the `"gated-gelu"` feed forward projection.
|
||||
encoder_attention_type (`string`, *optional*, defaults to `"local"`):
|
||||
Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are
|
||||
supported by LongT5 implementation.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
"""
|
||||
model_type = "longt5"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32128,
|
||||
d_model=512,
|
||||
d_kv=64,
|
||||
d_ff=2048,
|
||||
num_layers=6,
|
||||
num_decoder_layers=None,
|
||||
num_heads=8,
|
||||
local_radius=127,
|
||||
global_block_size=16,
|
||||
relative_attention_num_buckets=32,
|
||||
relative_attention_max_distance=128,
|
||||
dropout_rate=0.1,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_factor=1.0,
|
||||
feed_forward_proj="relu",
|
||||
is_encoder_decoder=True,
|
||||
encoder_attention_type="local",
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.d_kv = d_kv
|
||||
self.d_ff = d_ff
|
||||
self.num_layers = num_layers
|
||||
# default = symmetry
|
||||
self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
|
||||
self.num_heads = num_heads
|
||||
self.local_radius = local_radius
|
||||
self.global_block_size = global_block_size
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = relative_attention_max_distance
|
||||
self.dropout_rate = dropout_rate
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
self.encoder_attention_type = encoder_attention_type
|
||||
self.use_cache = use_cache
|
||||
|
||||
act_info = self.feed_forward_proj.split("-")
|
||||
self.dense_act_fn = act_info[-1]
|
||||
self.is_gated_act = act_info[0] == "gated"
|
||||
|
||||
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
|
||||
raise ValueError(
|
||||
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
|
||||
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
|
||||
"'gated-gelu' or 'relu'"
|
||||
)
|
||||
|
||||
# for backwards compatibility
|
||||
if feed_forward_proj == "gated-gelu":
|
||||
self.dense_act_fn = "gelu_new"
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = {
|
||||
"input_ids": {0: "batch", 1: "encoder_sequence"},
|
||||
"attention_mask": {0: "batch", 1: "encoder_sequence"},
|
||||
}
|
||||
if self.use_past:
|
||||
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
|
||||
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
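
A short usage sketch for the configuration class defined above; the hyper-parameter values are illustrative only and do not correspond to a released checkpoint. It shows the derived attributes computed in `__init__`: `num_decoder_layers` falling back to `num_layers`, and `dense_act_fn`/`is_gated_act` parsed from `feed_forward_proj`.

```python
from transformers import LongT5Config

# Illustrative hyper-parameters only, not a released checkpoint configuration.
config = LongT5Config(
    d_model=512,
    num_layers=6,
    num_heads=8,
    local_radius=127,
    global_block_size=16,
    encoder_attention_type="transient-global",
    feed_forward_proj="gated-gelu",
)

print(config.num_decoder_layers)  # 6 -- defaults to num_layers when not given
print(config.dense_act_fn, config.is_gated_act)  # gelu_new True -- parsed from "gated-gelu"
```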
@ -0,0 +1,214 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of
|
||||
'src/transformers/models/t5/convert_t5x_checkpoint_to_flax.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from t5x import checkpoints
|
||||
from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
|
||||
|
||||
|
||||
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
|
||||
config = AutoConfig.from_pretrained(config_name)
|
||||
flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
|
||||
t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
||||
|
||||
split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
|
||||
|
||||
if config.model_type == "t5":
    encoder_attn_name = "SelfAttention"
elif config.model_type == "longt5" and config.encoder_attention_type == "local":
    encoder_attn_name = "LocalSelfAttention"
elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
    encoder_attn_name = "TransientGlobalSelfAttention"
else:
    raise ValueError(
        "Given config is expected to have `model_type='t5'`, or `model_type='longt5'` with `encoder_attention_type`"
        " attribute with a value from ['local', 'transient-global']."
    )
|
||||
|
||||
# Encoder
|
||||
for layer_index in range(config.num_layers):
|
||||
layer_name = f"layers_{str(layer_index)}"
|
||||
|
||||
# Self-Attention
|
||||
t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
|
||||
t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
|
||||
t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
|
||||
t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
|
||||
|
||||
# Global input layer norm
|
||||
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
|
||||
t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"]
|
||||
|
||||
# Layer Normalization
|
||||
t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
|
||||
|
||||
if split_mlp_wi:
|
||||
t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
|
||||
t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
|
||||
else:
|
||||
t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
|
||||
|
||||
t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
|
||||
|
||||
# Layer Normalization
|
||||
t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
|
||||
|
||||
# Assigning
|
||||
flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"]
|
||||
flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key
|
||||
flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out
|
||||
flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query
|
||||
flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value
|
||||
|
||||
flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
|
||||
|
||||
# Global input layer norm
|
||||
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
|
||||
flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"][
|
||||
"weight"
|
||||
] = t5x_global_layer_norm
|
||||
|
||||
if split_mlp_wi:
|
||||
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
|
||||
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
|
||||
else:
|
||||
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
|
||||
|
||||
flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
|
||||
flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
|
||||
|
||||
flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block
|
||||
|
||||
# Only for layer 0:
|
||||
t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
|
||||
flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
|
||||
"embedding"
|
||||
] = t5x_encoder_rel_embedding
|
||||
|
||||
# Side/global relative position_bias + layer norm
|
||||
if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
|
||||
t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
|
||||
flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
|
||||
"embedding"
|
||||
] = t5x_encoder_global_rel_embedding
|
||||
|
||||
# Assigning
|
||||
t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
|
||||
flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
|
||||
|
||||
# Decoder
|
||||
for layer_index in range(config.num_layers):
|
||||
layer_name = f"layers_{str(layer_index)}"
|
||||
|
||||
# Self-Attention
|
||||
t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
|
||||
t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
|
||||
t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
|
||||
t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
|
||||
|
||||
# Layer Normalization
|
||||
t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][
|
||||
"scale"
|
||||
]
|
||||
|
||||
# Encoder-Decoder-Attention
|
||||
t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]
|
||||
t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"]
|
||||
t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"]
|
||||
t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"]
|
||||
t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"]
|
||||
|
||||
# Layer Normalization
|
||||
t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
|
||||
|
||||
# MLP
|
||||
if split_mlp_wi:
|
||||
t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
|
||||
t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
|
||||
else:
|
||||
t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
|
||||
|
||||
t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
|
||||
|
||||
# Layer Normalization
|
||||
tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
|
||||
|
||||
# Assigning
|
||||
flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"]
|
||||
flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
|
||||
flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
|
||||
flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
|
||||
flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
|
||||
|
||||
flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
|
||||
|
||||
flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
|
||||
flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
|
||||
flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
|
||||
flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
|
||||
|
||||
flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
|
||||
|
||||
if split_mlp_wi:
|
||||
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
|
||||
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
|
||||
else:
|
||||
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
|
||||
|
||||
flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
|
||||
|
||||
flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
|
||||
|
||||
flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block
|
||||
|
||||
# Decoder Normalization
|
||||
tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
|
||||
flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
|
||||
|
||||
# Only for layer 0:
|
||||
t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
|
||||
flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
|
||||
"embedding"
|
||||
] = t5x_decoder_rel_embedding
|
||||
|
||||
# Token Embeddings
|
||||
tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
|
||||
flax_model.params["shared"]["embedding"] = tx5_token_embeddings
|
||||
|
||||
# LM Head (only in v1.1 and LongT5 checkpoints)
|
||||
if "logits_dense" in t5x_model["target"]["decoder"]:
|
||||
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
|
||||
|
||||
flax_model.save_pretrained(flax_dump_folder_path)
|
||||
print("T5X Model was sucessfully converted!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint."
|
||||
)
|
||||
parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.")
|
||||
parser.add_argument(
|
||||
"--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
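
For completeness, a hedged sketch of driving the converter from Python rather than the command line. The paths are hypothetical placeholders, and `convert_t5x_checkpoint_to_flax` is assumed to be in scope (for example, after importing the script above; its file path is not shown in this diff).

```python
# Hypothetical placeholder paths -- point them at a real T5X/LongT5X checkpoint and an output directory.
t5x_checkpoint_path = "/path/to/longt5x_checkpoint"
config_name = "google/LongT5-TGlobal-Base"  # any LongT5/T5 config name, mirroring --config_name
flax_dump_folder_path = "/tmp/longt5_flax"

# Same call that the argparse entry point above performs.
convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path)
```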
|
src/transformers/models/longt5/modeling_flax_longt5.py (new file, 2401 lines; diff suppressed because it is too large)
src/transformers/models/longt5/modeling_longt5.py (new file, 2192 lines; diff suppressed because it is too large)
@ -281,6 +281,13 @@ class FeaturesManager:
|
||||
"token-classification",
|
||||
onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
|
||||
),
|
||||
"longt5": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
"seq2seq-lm",
|
||||
"seq2seq-lm-with-past",
|
||||
onnx_config_cls="models.longt5.LongT5OnnxConfig",
|
||||
),
|
||||
"marian": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
|
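
With the feature mapping registered above, the generic `transformers.onnx` export path should also cover LongT5. The sketch below is an assumption-laden illustration rather than a recipe from this PR: the checkpoint name is reused from the model card, the output path is hypothetical, and `"seq2seq-lm"` is one of the features listed for `"longt5"` above.

```python
from pathlib import Path

from transformers import AutoTokenizer, LongT5ForConditionalGeneration
from transformers.models.longt5 import LongT5OnnxConfig
from transformers.onnx import export

ckpt = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"  # checkpoint reused from the model card
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = LongT5ForConditionalGeneration.from_pretrained(ckpt)

# LongT5OnnxConfig is exported by the new longt5 module; opset 13 is its default (see the config file above).
onnx_config = LongT5OnnxConfig(model.config, task="seq2seq-lm")
onnx_path = Path("longt5.onnx")  # hypothetical output file

onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```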
@ -725,6 +725,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxLongT5Model(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxLongT5PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxMarianModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
@ -2605,6 +2605,37 @@ class LongformerSelfAttention(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class LongT5EncoderModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LongT5ForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LongT5Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LongT5PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
test_results.txt (new file, 9 lines)
@ -0,0 +1,9 @@
|
||||
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
|
||||
|
||||
|
||||
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
|
||||
|
||||
|
||||
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
|
||||
|
||||
|

0  tests/models/longt5/__init__.py  Normal file
757  tests/models/longt5/test_modeling_flax_longt5.py  Normal file
@ -0,0 +1,757 @@
# coding=utf-8
# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import numpy as np

import transformers
from transformers import is_flax_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
    is_pt_flax_cross_test,
    require_flax,
    require_sentencepiece,
    require_tokenizers,
    slow,
)

from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor


if is_flax_available():
    import os

    # The slow tests are often failing with OOM error on GPU
    # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
    # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
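
    # Related knobs from the same JAX docs page, in case the "platform" allocator
    # proves too slow on a given setup (shown for reference only; the tests rely
    # on the line above): preallocation can be disabled, or the preallocated
    # fraction of GPU memory can be capped.
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"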

    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict

    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
    from transformers.models.longt5.modeling_flax_longt5 import (
        FlaxLongT5ForConditionalGeneration,
        FlaxLongT5Model,
        shift_tokens_right,
    )


class FlaxLongT5ModelTester:
    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        encoder_seq_length=7,
        decoder_seq_length=9,
        local_radius=5,
        encoder_attention_type="local",
        global_block_size=3,
        # For common tests
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
        decoder_layers=None,
    ):

        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.local_radius = local_radius
        self.block_len = local_radius + 1
        self.encoder_attention_type = encoder_attention_type
        self.global_block_size = global_block_size
        # For common tests
        self.seq_length = self.decoder_seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.scope = None
        self.decoder_layers = decoder_layers

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

        config = LongT5Config(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            local_radius=self.local_radius,
            encoder_attention_type=self.encoder_attention_type,
            global_block_size=self.global_block_size,
        )

        return (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        model = FlaxLongT5Model(config=config)
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        decoder_output = result.last_hidden_state
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
        self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size))

    def check_use_cache_forward_with_attn_mask(
        self,
        model_class_name,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        max_decoder_length = 20
        model = model_class_name(config)

        encoder_outputs = model.encode(input_ids)

        # prevent fully zero'd out attention mask
        decoder_attention_mask = jnp.ones_like(decoder_attention_mask)

        decoder_attention_mask_cache = jnp.concatenate(
            [
                decoder_attention_mask,
                jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
            ],
            axis=-1,
        )

        past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)

        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            encoder_outputs,
            decoder_attention_mask=decoder_attention_mask_cache,
            past_key_values=past_key_values,
        )
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            encoder_outputs,
            past_key_values=outputs_cache.past_key_values,
            decoder_attention_mask=decoder_attention_mask_cache,
        )

        outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
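
    # Reading of the check above: the decoder first consumes all but the last
    # token while filling `past_key_values`, then decodes the final token from
    # that cache, and finally the whole sequence is decoded again without a
    # cache; if incremental decoding is consistent with the full forward pass,
    # the last-token outputs of both runs should agree to within ~1e-3.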

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict


@require_flax
class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):

    all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
    all_generative_model_classes = (FlaxLongT5ForConditionalGeneration,) if is_flax_available() else ()
    is_encoder_decoder = True

    def setUp(self):
        self.model_tester = FlaxLongT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated gelu feed forward and different word embeddings work
        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_use_cache_forward_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)

    def test_encode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model.encode(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_decode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "encoder_outputs": encoder_outputs,
                }

                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        encoder_outputs=encoder_outputs,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_shift_right(self):
        decoder_start_token_id = 0
        pad_token_id = 1
        labels = np.arange(2, 102).reshape(5, 20)
        labels[:2, 15:] = -100

        decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
        np_decoder_input_ids = np.array(decoder_input_ids)

        padded_slice = np_decoder_input_ids[:2, (15 + 1) :]
        self.assertTrue((padded_slice == 1).all())

        not_padded_slice = np_decoder_input_ids[2:, 1:]
        rolled_labels = np.roll(labels[2:], 1)[:, 1:]
        self.assertTrue((not_padded_slice == rolled_labels).all())
        self.assertTrue((np_decoder_input_ids[:, 0] == 0).all())
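
    # What the check above expects from `shift_tokens_right`: prepend
    # `decoder_start_token_id`, drop the last label, and replace any -100
    # (ignored-label) positions with `pad_token_id`, so the decoder inputs are
    # the labels shifted one position to the right.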

    # overwrite since special base model prefix is used
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname)

                base_param_from_head = flatten_dict(unfreeze(head_model.params))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_length = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        block_len = getattr(self.model_tester, "block_len", None)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, block_len, 3 * block_len],
            )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, block_len, 3 * block_len],
            )
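
    # Quick arithmetic behind the expected local-attention shape above, using the
    # tester defaults (illustrative only): local_radius=5 gives
    # block_len = 5 + 1 = 6, so each block of 6 queries attends to 3 * 6 = 18
    # keys, i.e. per-block attention weights of shape
    # (num_heads, block_len, 3 * block_len) = (4, 6, 18).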

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
    def setUp(self):
        self.model_tester = FlaxLongT5ModelTester(self, encoder_attention_type="transient-global")
        self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_length = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        block_len = getattr(self.model_tester, "block_len", None)
        global_block_size = getattr(self.model_tester, "global_block_size", None)
        global_seq_len = encoder_seq_length // global_block_size

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
            )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
            )
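
# For the transient-global variant above, the same tester defaults additionally
# give global_seq_len = encoder_seq_length // global_block_size = 7 // 3 = 2, so
# the expected per-block attention shape grows to
# (num_heads, block_len, 3 * block_len + global_seq_len) = (4, 6, 20)
# (illustrative arithmetic only).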


@require_sentencepiece
@require_tokenizers
@require_flax
class FlaxLongT5ModelIntegrationTests(unittest.TestCase):
    model_path = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"

    def expected_summary(self):
        return [
            "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
            " developing world . it provides an excellent resolution for visualization of the coronary arteries for"
            " catheter - based or operating interventions . although the association of this technique with major"
            " complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
            " and noncardiac complications . computed tomography coronary angiography is a promising technique for the"
            " evaluation of cad noninvasively . it assesses disease within the coronary artery and provides"
            " qualitative and quantitative information about nonobstructive atherosclerotic plaque"
        ]

    @slow
    def test_summarization(self):
        model = FlaxLongT5ForConditionalGeneration.from_pretrained(self.model_path)
        tok = AutoTokenizer.from_pretrained(self.model_path)

        ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n
although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is
a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel
wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice ,
128-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with
significant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women
aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically
proven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major
noncoronary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients
scheduled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed
consent for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n
patients with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients
with heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct
technology , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector
row are acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material
( omnipaque ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of
contrast material , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached
in a region of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram .
a single observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary
angiography . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the
mean lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary
angiography . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis
of the number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were
used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the
diagnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and
positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease
per vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary
lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal
irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary
angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated
software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified
lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar
reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection
of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact
95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n
in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement ,
double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the
study group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct
coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to
intermediate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a
patient who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the
diagnostic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% )
and 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these
patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy
or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc .
in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with
invasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a
pretest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a
lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary
angiography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from
a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography
could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in
patients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of
patients is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad ,
gujarat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """

        dct = tok(
            [ARTICLE],
            max_length=1024,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        hypotheses_batch = model.generate(
            **dct,
            num_beams=4,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            do_sample=False,
            early_stopping=True,
        ).sequences

        decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        self.assertListEqual(
            self.expected_summary(),
            decoded,
        )
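
The integration test above is gated behind the `slow` decorator and downloads a large checkpoint, so it only runs when slow tests are explicitly enabled. A typical invocation, assuming the usual transformers testing setup where RUN_SLOW=1 unlocks @slow tests:

    RUN_SLOW=1 python -m pytest tests/models/longt5/test_modeling_flax_longt5.py -k test_summarization -v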

1315  tests/models/longt5/test_modeling_longt5.py  Normal file
File diff suppressed because it is too large
@ -213,6 +213,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("blenderbot-small", "facebook/blenderbot_small-90M"),
    ("blenderbot", "facebook/blenderbot-400M-distill"),
    ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
    ("longt5", "google/LongT5-Local-Base"),
    ("longt5", "google/LongT5-TGlobal-Base"),
}

# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
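
With these mappings registered, a LongT5 checkpoint should be exportable through the ONNX tooling exercised by this test file. A minimal sketch, assuming the `transformers.onnx` CLI and feature names available around the time of this commit (the exact feature string may differ):

    python -m transformers.onnx --model=google/LongT5-Local-Base --feature=seq2seq-lm-with-past onnx_output/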

@ -18,6 +18,7 @@ from transformers import (
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    LEDConfig,
    LongT5Config,
    SummarizationPipeline,
    T5Config,
    pipeline,
@ -54,8 +55,8 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
        )
        self.assertEqual(outputs, [{"summary_text": ANY(str)}])

        if not isinstance(model.config, (T5Config, LEDConfig)):
            # LED, T5 can handle it.
        if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
            # LED, T5, LongT5 can handle it.
            # Too long.
            with self.assertRaises(Exception):
                outputs = summarizer("This " * 1000)
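
Since the summarization pipeline now treats LongT5 configs as long-input-capable, a long document no longer has to hit the "Too long" failure path. A minimal usage sketch, assuming the LongT5 checkpoint referenced elsewhere in this commit and a `long_document` string supplied by the caller:

    from transformers import pipeline

    summarizer = pipeline("summarization", model="Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
    print(summarizer(long_document, max_length=142, min_length=56)[0]["summary_text"])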

@ -36,6 +36,7 @@ PATH_TO_DOC = "docs/source/en"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "LongT5Stack",
    "RealmBertModel",
    "T5Stack",
    "TFDPRSpanPredictor",

@ -36,6 +36,7 @@ src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
src/transformers/models/longformer/modeling_longformer.py
src/transformers/models/longformer/modeling_tf_longformer.py
src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/marian/modeling_marian.py
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py