diff --git a/README.md b/README.md index 59e36e73a28..0f3ff8b3998 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. +1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. diff --git a/README_ko.md b/README_ko.md index 54dc8105442..b3df8b803e5 100644 --- a/README_ko.md +++ b/README_ko.md @@ -265,6 +265,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. +1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. diff --git a/README_zh-hans.md b/README_zh-hans.md index d65d1db19bf..96b000c7d70 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -289,6 +289,7 @@ conda install -c huggingface transformers 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (来自 Meta AI) 伴随论文 [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) 由 Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze 发布。 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。 +1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (来自 Google AI) released 伴随论文 [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) 由 Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang 发布。 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (来自 Studio Ousia) 伴随论文 [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) 由 Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto 发布。 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (来自 UNC Chapel Hill) 伴随论文 [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) 由 Hao Tan and Mohit Bansal 发布。 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (来自 Facebook) 伴随论文 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 由 Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 1f020a7326d..eeafb1c5756 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -301,6 +301,7 @@ conda install -c huggingface transformers 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. +1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7f4c4c68fcf..3d6923bffca 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -256,6 +256,8 @@ title: LeViT - local: model_doc/longformer title: Longformer + - local: model_doc/longt5 + title: LongT5 - local: model_doc/luke title: LUKE - local: model_doc/lxmert diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 6043a9ba72f..349174d35ad 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -107,6 +107,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LeViT](model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze. 1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. +1. **[LongT5](model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. 1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M-CTC-T](model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. @@ -233,6 +234,7 @@ Flax), PyTorch, and/or TensorFlow. | LED | ✅ | ✅ | ✅ | ✅ | ❌ | | LeViT | ❌ | ❌ | ✅ | ❌ | ❌ | | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ | +| LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ | | LUKE | ✅ | ❌ | ✅ | ❌ | ❌ | | LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ | | M-CTC-T | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/longt5.mdx b/docs/source/en/model_doc/longt5.mdx new file mode 100644 index 00000000000..27a1d685158 --- /dev/null +++ b/docs/source/en/model_doc/longt5.mdx @@ -0,0 +1,121 @@ + + +# LongT5 + +## Overview + +The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) +by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung and Yinfei Yang. It's an +encoder-decoder transformer pre-trained in a text-to-text denoising generative setting. LongT5 model is an extension of +T5 model, and it enables using one of the two different efficient attention mechanisms - (1) Local attention, or (2) +Transient-Global attention. + + +The abstract from the paper is the following: + +*Recent work has shown that either (1) increasing the input length or (2) increasing model size can improve the +performance of Transformer-based neural models. In this paper, we present a new model, called LongT5, with which we +explore the effects of scaling both the input length and model size at the same time. Specifically, we integrated +attention ideas from long-input transformers (ETC), and adopted pre-training strategies from summarization pre-training +(PEGASUS) into the scalable T5 architecture. The result is a new attention mechanism we call {\em Transient Global} +(TGlobal), which mimics ETC's local/global attention mechanism, but without requiring additional side-inputs. We are +able to achieve state-of-the-art results on several summarization tasks and outperform the original T5 models on +question answering tasks.* + +Tips: + +- [`LongT5ForConditionalGeneration`] is an extension of [`T5ForConditionalGeneration`] exchanging the traditional +encoder *self-attention* layer with efficient either *local* attention or *transient-global* (*tglobal*) attention. +- Unlike the T5 model, LongT5 does not use a task prefix. Furthermore, it uses a different pre-training objective +inspired by the pre-training of `[PegasusForConditionalGeneration]`. +- LongT5 model is designed to work efficiently and very well on long-range *sequence-to-sequence* tasks where the +input sequence exceeds commonly used 512 tokens. It is capable of handling input sequences of a length up to 16,384 tokens. +- For *Local Attention*, the sparse sliding-window local attention operation allows a given token to attend only `r` +tokens to the left and right of it (with `r=127` by default). *Local Attention* does not introduce any new parameters +to the model. The complexity of the mechanism is linear in input sequence length `l`: `O(l*r)`. +- *Transient Global Attention* is an extension of the *Local Attention*. It, furthermore, allows each input token to +interact with all other tokens in the layer. This is achieved via splitting an input sequence into blocks of a fixed +length `k` (with a default `k=16`). Then, a global token for such a block is obtained via summing and normalizing the embeddings of every token +in the block. Thanks to this, the attention allows each token to attend to both nearby tokens like in Local attention, and +also every global token like in the case of standard global attention (*transient* represents the fact the global tokens +are constructed dynamically within each attention operation). As a consequence, *TGlobal* attention introduces +a few new parameters -- global relative position biases and a layer normalization for global token's embedding. +The complexity of this mechanism is `O(l(r + l/k))`. +- An example showing how to evaluate a fine-tuned LongT5 model on the [pubmed dataset](https://huggingface.co/datasets/scientific_papers) is below. + +```python +>>> import evaluate +>>> from datasets import load_dataset +>>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration + +>>> dataset = load_dataset("scientific_papers", "pubmed", split="validation") +>>> model = ( +... LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps") +... .to("cuda") +... .half() +... ) +>>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps") + + +>>> def generate_answers(batch): +... inputs_dict = tokenizer( +... batch["article"], max_length=16384, padding="max_length", truncation=True, return_tensors="pt" +... ) +... input_ids = inputs_dict.input_ids.to("cuda") +... attention_mask = inputs_dict.attention_mask.to("cuda") +... output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=512, num_beams=2) +... batch["predicted_abstract"] = tokenizer.batch_decode(output_ids, skip_special_tokens=True) +... return batch + + +>>> result = dataset.map(generate_answer, batched=True, batch_size=2) +>>> rouge = evaluate.load("rouge") +>>> rouge.compute(predictions=result["predicted_abstract"], references=result["abstract"]) +``` + +This model was contributed by [stancld](https://huggingface.co/stancld). +The original code can be found [here](https://github.com/google-research/longt5). + + +## LongT5Config + +[[autodoc]] LongT5Config + +## LongT5Model + +[[autodoc]] LongT5Model + - forward + +## LongT5ForConditionalGeneration + +[[autodoc]] LongT5ForConditionalGeneration + - forward + +## LongT5EncoderModel + +[[autodoc]] LongT5EncoderModel + - forward + +## FlaxLongT5Model + +[[autodoc]] FlaxLongT5Model + - __call__ + - encode + - decode + +## FlaxLongT5ForConditionalGeneration + +[[autodoc]] FlaxLongT5ForConditionalGeneration + - __call__ + - encode + - decode diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index d6579ec239d..ce52d082145 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -66,6 +66,7 @@ Ready-made configurations include the following architectures: - GPT-J - I-BERT - LayoutLM +- LongT5 - M2M100 - Marian - mBART diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b55ad04dc7e..5e0df3ad6dd 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -239,6 +239,7 @@ _import_structure = { "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], "models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"], "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"], + "models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"], "models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"], "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"], "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], @@ -1277,6 +1278,15 @@ else: "LongformerSelfAttention", ] ) + _import_structure["models.longt5"].extend( + [ + "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST", + "LongT5EncoderModel", + "LongT5ForConditionalGeneration", + "LongT5Model", + "LongT5PreTrainedModel", + ] + ) _import_structure["models.luke"].extend( [ "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2586,6 +2596,9 @@ else: ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] ) _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) + _import_structure["models.longt5"].extend( + ["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"] + ) _import_structure["models.marian"].extend( [ "FlaxMarianModel", @@ -2850,6 +2863,7 @@ if TYPE_CHECKING: from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer + from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config @@ -3726,6 +3740,13 @@ if TYPE_CHECKING: LongformerPreTrainedModel, LongformerSelfAttention, ) + from .models.longt5 import ( + LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST, + LongT5EncoderModel, + LongT5ForConditionalGeneration, + LongT5Model, + LongT5PreTrainedModel, + ) from .models.luke import ( LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, LukeForEntityClassification, @@ -4784,6 +4805,7 @@ if TYPE_CHECKING: from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel + from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel from .models.mbart import ( FlaxMBartForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0818cebe175..bb0d7bccb74 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -76,6 +76,7 @@ from . import ( led, levit, longformer, + longt5, luke, lxmert, m2m_100, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 31e34125c65..7ce4b01b028 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -78,6 +78,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("led", "LEDConfig"), ("levit", "LevitConfig"), ("longformer", "LongformerConfig"), + ("longt5", "LongT5Config"), ("luke", "LukeConfig"), ("lxmert", "LxmertConfig"), ("m2m_100", "M2M100Config"), @@ -193,6 +194,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -310,6 +312,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("led", "LED"), ("levit", "LeViT"), ("longformer", "Longformer"), + ("longt5", "LongT5"), ("luke", "LUKE"), ("lxmert", "LXMERT"), ("m2m_100", "M2M100"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3f2854f9136..3b3bc90b60b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -77,6 +77,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("led", "LEDModel"), ("levit", "LevitModel"), ("longformer", "LongformerModel"), + ("longt5", "LongT5Model"), ("luke", "LukeModel"), ("lxmert", "LxmertModel"), ("m2m_100", "M2M100Model"), @@ -217,6 +218,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ("layoutlm", "LayoutLMForMaskedLM"), ("led", "LEDForConditionalGeneration"), ("longformer", "LongformerForMaskedLM"), + ("longt5", "LongT5ForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"), ("marian", "MarianMTModel"), ("megatron-bert", "MegatronBertForCausalLM"), @@ -423,6 +425,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("encoder-decoder", "EncoderDecoderModel"), ("fsmt", "FSMTForConditionalGeneration"), ("led", "LEDForConditionalGeneration"), + ("longt5", "LongT5ForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"), ("marian", "MarianMTModel"), ("mbart", "MBartForConditionalGeneration"), diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 5eda5a74162..98c5d6fb5a1 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -41,6 +41,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ("gpt2", "FlaxGPT2Model"), ("gpt_neo", "FlaxGPTNeoModel"), ("gptj", "FlaxGPTJModel"), + ("longt5", "FlaxLongT5Model"), ("marian", "FlaxMarianModel"), ("mbart", "FlaxMBartModel"), ("mt5", "FlaxMT5Model"), @@ -65,6 +66,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("bert", "FlaxBertForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"), ("electra", "FlaxElectraForPreTraining"), + ("longt5", "FlaxLongT5ForConditionalGeneration"), ("mbart", "FlaxMBartForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), ("roberta", "FlaxRobertaForMaskedLM"), @@ -98,6 +100,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), ("encoder-decoder", "FlaxEncoderDecoderModel"), + ("longt5", "FlaxLongT5ForConditionalGeneration"), ("marian", "FlaxMarianMTModel"), ("mbart", "FlaxMBartForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 49a1d7f8227..76a6dee790d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -137,6 +137,13 @@ else: ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), + ( + "longt5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), ("luke", ("LukeTokenizer", None)), ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/longt5/__init__.py b/src/transformers/models/longt5/__init__.py new file mode 100644 index 00000000000..fd355f6d5a9 --- /dev/null +++ b/src/transformers/models/longt5/__init__.py @@ -0,0 +1,88 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available + + +_import_structure = { + "configuration_longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config", "LongT5OnnxConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_longt5"] = [ + "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST", + "LongT5EncoderModel", + "LongT5ForConditionalGeneration", + "LongT5Model", + "LongT5PreTrainedModel", + ] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_longt5"] = [ + "FlaxLongT5ForConditionalGeneration", + "FlaxLongT5Model", + "FlaxLongT5PreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config, LongT5OnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_longt5 import ( + LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST, + LongT5EncoderModel, + LongT5ForConditionalGeneration, + LongT5Model, + LongT5PreTrainedModel, + ) + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_longt5 import ( + FlaxLongT5ForConditionalGeneration, + FlaxLongT5Model, + FlaxLongT5PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py new file mode 100644 index 00000000000..e120055bb46 --- /dev/null +++ b/src/transformers/models/longt5/configuration_longt5.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2022, The LongT5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LongT5 model configuration""" +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + +LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/LongT5-Local-Base": "https://huggingface.co/google/LongT5-Local-Base/blob/main/config.json", + "google/LongT5-Local-Large": "https://huggingface.co/google/LongT5-Local-Large/blob/main/config.json", + "google/LongT5-TGlobal-Base": "https://huggingface.co/google/LongT5-TGlobal-Base/blob/main/config.json", + "google/LongT5-TGlobal-Large": "https://huggingface.co/google/LongT5-TGlobal-Large/blob/main/config.json", +} + + +class LongT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is + used to instantiate a LongT5 model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5 + [google/LongT5-Local-Base](https://huggingface.co/google/LongT5-Local-Base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LongT5Model`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model // + num_heads`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `LongT5Block`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + local_radius (`int`, *optional*, defaults to 127) + Number of tokens to the left/right for each token to locally self-attend in a local attention mechanism. + global_block_size (`int`, *optional*, defaults to 16) + Lenght of blocks an input sequence is divided into for a global token representation. Used only for + `encoder_attention_type = "transient-global"`. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. LongT5v1.1 uses the + `"gated-gelu"` feed forward projection. Original LongT5 implementation uses `"gated-gelu"`. + encoder_attention_type (`string`, *optional*, defaults to `"local"`): + Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are + supported by LongT5 implementation. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + model_type = "longt5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + local_radius=127, + global_block_size=16, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + encoder_attention_type="local", + use_cache=True, + pad_token_id=0, + eos_token_id=1, + **kwargs + ): + + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + # default = symmetry + self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers + self.num_heads = num_heads + self.local_radius = local_radius + self.global_block_size = global_block_size + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.encoder_attention_type = encoder_attention_type + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py new file mode 100644 index 00000000000..41cc3a2005d --- /dev/null +++ b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of +'src/transformers/models/t5/convert_t5x_checkpoint_to_flax. +""" + +import argparse + +from t5x import checkpoints +from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM + + +def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path): + config = AutoConfig.from_pretrained(config_name) + flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config) + t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + + split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"] + + if config.model_type == "t5": + encoder_attn_name = "SelfAttention" + if config.model_type == "longt5" and config.encoder_attention_type == "local": + encoder_attn_name = "LocalSelfAttention" + elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global": + encoder_attn_name = "TransientGlobalSelfAttention" + else: + raise ValueError( + "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`" + " attribute with a value from ['local', 'transient-global]." + ) + + # Encoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] + t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] + t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] + t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + + # Global input layer norm + if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": + t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"] + + # Layer Normalization + t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + + if split_mlp_wi: + t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] + t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + + t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"] + flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key + flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out + flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query + flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value + + flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm + + # Global input layer norm + if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": + flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"][ + "weight" + ] = t5x_global_layer_norm + + if split_mlp_wi: + flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0 + flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1 + else: + flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi + + flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo + flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block + + # Only for layer 0: + t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][ + "embedding" + ] = t5x_encoder_rel_embedding + + # Side/global relative position_bias + layer norm + if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": + t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][ + "embedding" + ] = t5x_encoder_global_rel_embedding + + # Assigning + t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm + + # Decoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] + t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] + t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] + t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + + # Layer Normalization + t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ + "scale" + ] + + # Encoder-Decoder-Attention + t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"] + t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"] + t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"] + t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"] + t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"] + + # Layer Normalization + t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + + # MLP + if split_mlp_wi: + t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] + t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + + t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"] + flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key + flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out + flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query + flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value + + flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm + + flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key + flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out + flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query + flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value + + flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm + + if split_mlp_wi: + flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0 + flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1 + else: + flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi + + flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo + + flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block + + # Decoder Normalization + tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"] + flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm + + # Only for layer 0: + t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = t5x_decoder_rel_embedding + + # Token Embeddings + tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"] + flax_model.params["shared"]["embedding"] = tx5_token_embeddings + + # LM Head (only in v1.1 and LongT5 checkpoints) + if "logits_dense" in t5x_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"] + + flax_model.save_pretrained(flax_dump_folder_path) + print("T5X Model was sucessfully converted!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint." + ) + parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.") + parser.add_argument( + "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + ) + args = parser.parse_args() + convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py new file mode 100644 index 00000000000..db9e34f8e54 --- /dev/null +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -0,0 +1,2401 @@ +# coding=utf-8 +# Copyright 2022 LongT5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax LongT5 model.""" + + +import copy +from typing import Any, Callable, List, Optional, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_longt5 import LongT5Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/LongT5-Local-Base" +_CONFIG_FOR_DOC = "LongT5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray: + """Pad an array so that a sequence length will be a multiple of `block_len`""" + pad_len = -x.shape[axis] % block_len + pad = [(0, 0)] * x.ndim + pad[axis] = (0, pad_len) + x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value) + return x + + +def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray: + """Split an input array into blocks of a given `block_len` along the given `axis`. If the dimension length + is not a multiple of `block_len`, it will be padded first with selected `pad_value`. + """ + # pad tensor to multiple of block_len + if x.shape[axis] % block_len != 0: + x = _pad_to_multiple(x, block_len, axis, pad_value=0) + num_blocks = x.shape[axis] // block_len + output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1) :] + return x.reshape(output_shape) + + +def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray: + """Concatenate three consecutive blocks for each input block for local attentiont. + For more information, see: https://arxiv.org/pdf/2112.07916.pdf. + """ + num_blocks = x.shape[block_axis] + + pad = [(0, 0)] * x.ndim + pad[block_axis] = (1, 1) + # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len] + x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value) + + blocks_list: List[np.array] = [] + for i in range(3): + # We use indexing approach here: + # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs + indices = [slice(0, None)] * x.ndim + indices[block_axis] = slice(i, i + num_blocks) + indices = tuple(indices) + blocks_list.append(x[indices]) + return jnp.concatenate(blocks_list, axis=sequence_axis) # [batch_size, num_blocks, 3 * block_len, ...] + + +def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray: + """Makes 3-blocked relative position ids for local attention.""" + position_ids = jnp.arange(3 * block_len, dtype=jnp.int32) + center_position_ids = position_ids[block_len:-block_len] + relative_position_ids = position_ids[None, :] - center_position_ids[:, None] # [block_len, 3 * block_len] + return relative_position_ids + + +def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray: + """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius.""" + relative_position_ids = _make_3block_relative_position_ids(block_len) + locality_mask = jnp.abs(relative_position_ids) < block_len + locality_mask = locality_mask[None, None, :, :] + return jnp.logical_and(local_attention_mask, locality_mask) + + +def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray: + """Prepare attention mask to be applied for a local attention.""" + # [batch_size, num_blocks, block_len] + _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1) + # [batch_size, num_block, 3 * block_len] + _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2) + + _blocked_attention_mask = _blocked_attention_mask[..., None] + _3blocked_attention_mask = _3blocked_attention_mask[..., None, :] + # [batch_size, num_block, block_len, 3 * block_len] + local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask) + local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len) + # [batch_size, 1, num_block, block_len, 3 * block_len] + return local_attention_mask[:, None, ...] + + +def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]: + """Obtain the "fixed block" global id corresponding to each input token. + + This implementation is a simlified version of the original Flaxformr implementation adopted from: + https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py. + + In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for + the whole fixed block, are assigned to the preceding block. + + Padding tokens from the original sequence are represented by -1. + """ + batch_size, seq_len = attention_mask.shape[:2] + + def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray: + block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1 + true_block_ends = jnp.logical_and(block_ends, block_ids >= 0) + full_blocks = true_block_ends.sum(-1)[..., None] + block_ids = jnp.minimum(block_ids, full_blocks - 1) + return block_ids + + fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size + fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask + mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0) + global_block_ids = jnp.maximum( + jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype) + ) + # set padding tokens to -1 + global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1) + # [batch_size, seq_len] + global_block_ids = handle_orphan_tokens(global_block_ids) + num_globals = seq_len // global_block_size + + # [batch_size, seq_len // global_block_size] + if num_globals > 0: + _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1) + else: + _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype) + global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1 + global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0) + return global_block_ids, global_segment_ids + + +def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray: + """Create the relative position tensor for local -> global attention.""" + block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size) + global_seq_len = global_segment_ids.shape[-1] + global_positions = jnp.arange(global_seq_len) + side_relative_position = global_positions - block_ids[..., None] + return side_relative_position + + +def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray: + """Compute individual block aggregates by summing over individual blocks.""" + # (batch..., seq_len, global_seq_len)) + one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len) + return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5 +class FlaxLongT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean. + """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5 +class FlaxLongT5DenseActDense(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->LongT5 +class FlaxLongT5DenseGatedActDense(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5 +class FlaxLongT5LayerFF(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5 +class FlaxLongT5Attention(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e4).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxLongT5LocalAttention(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.local_radius = self.config.local_radius + self.block_len = self.local_radius + 1 + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + ) + + @staticmethod + # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, block_length: int): + """Compute binned relative position bias""" + memory_position = jnp.arange(3 * block_length, dtype="i4") + context_position = memory_position[block_length:-block_length] + + relative_position = memory_position[None, :] - context_position[:, None] + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim) + + def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray: + # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) + if self.has_relative_attention_bias: + position_bias = self.compute_bias(block_len) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype) + + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim) + query_states = _split_into_blocks(query_states, self.block_len, axis=1) + key_states = _split_into_blocks(key_states, self.block_len, axis=1) + value_states = _split_into_blocks(value_states, self.block_len, axis=1) + + # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) + key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2) + value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + if attention_mask is not None: + attention_mask = _get_local_attention_mask(attention_mask, self.block_len) + + # replace masked positions with -10_000 + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias(self.block_len, attention_mask) + + if attention_mask is not None: + position_bias = position_bias + attention_mask.swapaxes(1, 2) + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + attn_output = attn_output[:, :seq_length, :] + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxLongT5TransientGlobalAttention(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.local_radius = self.config.local_radius + self.block_len = self.local_radius + 1 + self.global_block_size = self.config.global_block_size + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + ) + + # Relativen attention bias & Layer norm for global attention + if self.has_relative_attention_bias: + self.global_relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + ) + self.global_input_layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + + @staticmethod + # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, block_length: int): + """Compute binned relative position bias""" + memory_position = jnp.arange(3 * block_length, dtype="i4") + context_position = memory_position[block_length:-block_length] + + relative_position = memory_position[None, :] - context_position[:, None] + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, None, :, :, :] + return values + + def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray: + # (batch_size, 1, 1, seq_len, global_seq_len) + side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...] + attention_side_bias = jax.lax.select( + side_attention_mask > 0, + jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype), + ) + # (batch_size, seq_len, global_seq_len) + side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size) + side_relative_position_bucket = self._relative_position_bucket( + side_relative_position, + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # (batch_size, seq_len, global_seq_len, num_heads) + side_bias = self.global_relative_attention_bias(side_relative_position_bucket) + + # (batch_size, 1, num_heads, seq_len, global_seq_len) + side_bias = jnp.transpose(side_bias, (0, 3, 1, 2)) + # (batch_size, num_heads, seq_len, global_seq_len) + attention_side_bias = attention_side_bias + side_bias + return attention_side_bias + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim) + + def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray: + # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) + if self.has_relative_attention_bias: + position_bias = self.compute_bias(block_len) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype) + + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # Prepare components for transient-global attention + # Obtain block_ids and global_segment_ids + # global_seq_len := seq_len // self.global_block_size + # shapes: (batch_size, seq_len) & (batch_size, global_seq_len) + block_ids, global_segment_ids = _make_global_fixed_block_ids( + attention_mask if attention_mask is not None else jnp.ones((batch_size, seq_length)), + self.global_block_size, + ) + # Create global inputs + _global_seq_len = global_segment_ids.shape[-1] + global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len) + global_inputs = self.global_input_layer_norm(global_inputs) + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # Get global/side key/value_states + side_key_states = self.k(global_inputs) + side_value_states = self.v(global_inputs) + + # reshape to (batch_size, global_seq_len, n_heads, head_dim) + side_key_states = self._split_heads(side_key_states) + side_value_states = self._split_heads(side_value_states) + + # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim) + query_states = _split_into_blocks(query_states, self.block_len, axis=1) + key_states = _split_into_blocks(key_states, self.block_len, axis=1) + value_states = _split_into_blocks(value_states, self.block_len, axis=1) + + # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) + key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2) + value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2) + + # Tile side inputs across local key/value blocks + # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head) + reps = [1] * (side_key_states.ndim + 1) + reps[1] = key_states.shape[1] + side_key_states = jnp.tile(side_key_states[:, None, ...], reps) + side_value_states = jnp.tile(side_value_states[:, None, ...], reps) + + # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones + # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head) + key_states = jnp.concatenate((key_states, side_key_states), axis=2) + value_states = jnp.concatenate((value_states, side_value_states), axis=2) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + if attention_mask is not None: + local_attention_mask = _get_local_attention_mask(attention_mask, self.block_len) + local_attention_mask = jax.lax.select( + local_attention_mask > 0, + jnp.full(local_attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(local_attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + local_attention_mask = None + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias(self.block_len, attention_mask) + if local_attention_mask is not None: + position_bias = position_bias + local_attention_mask.swapaxes(1, 2) + + # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len) + if attention_mask is None: + attention_mask = jnp.ones((batch_size, seq_length)) + side_position_bias = self.compute_side_bias(attention_mask, global_segment_ids) + side_position_bias = _split_into_blocks(side_position_bias, self.block_len, axis=-2) + side_position_bias = jnp.swapaxes(side_position_bias, 1, 2) + position_bias = jnp.concatenate((position_bias, side_position_bias), axis=-1) + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + attn_output = attn_output[:, :seq_length, :] + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxLongT5LayerLocalSelfAttention(nn.Module): + """Local self attention used in encoder""" + + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.LocalSelfAttention = FlaxLongT5LocalAttention( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + self.layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + **kwargs: Any, # to accept init_cache kwargs + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.LocalSelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module): + """Transient-Global self attention used in encoder""" + + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + self.layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + **kwargs: Any, # to accept init_cache kwargs + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.TransientGlobalSelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5 +class FlaxLongT5LayerSelfAttention(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxLongT5Attention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5 +class FlaxLongT5LayerCrossAttention(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxLongT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxLongT5Block(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + if self.causal: + attention_layer = FlaxLongT5LayerSelfAttention + elif self.config.encoder_attention_type == "local": + attention_layer = FlaxLongT5LayerLocalSelfAttention + elif self.config.encoder_attention_type == "transient-global": + attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention + else: + raise ValueError( + "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, " + f"but got {self.config.encoder_attention_type}." + ) + self.layer = ( + attention_layer( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5 + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5 +class FlaxLongT5LayerCollection(nn.Module): + config: LongT5Config + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxLongT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5 +class FlaxLongT5BlockCollection(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.blocks = [ + FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i)) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5 +class FlaxLongT5Stack(nn.Module): + config: LongT5Config + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype) + self.final_layer_norm = FlaxLongT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +LONGT5_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 + Training](./longt5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +LONGT5_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +LONGT5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 + Training](./longt5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5 + Training](./longt5#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LongT5Config + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: LongT5Config, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base") + >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/LongT5-Local-Base") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base") + >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/LongT5-Local-Base") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxLongT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + +LONGT5_START_DOCSTRING = r""" + The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long + Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo + Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising + generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different + efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`LongT5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +@add_start_docstrings( + "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.", + LONGT5_START_DOCSTRING, +) +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5 +class FlaxLongT5Module(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5 +class FlaxLongT5Model(FlaxLongT5PreTrainedModel): + module_class = FlaxLongT5Module + + +append_call_sample_docstring( + FlaxLongT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + +FLAX_LONGT5_MODEL_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, FlaxLongT5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base") + >>> model = FlaxLongT5Model.from_pretrained("google/LongT5-Local-Base") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="np" + ... ).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5 +class FlaxLongT5ForConditionalGenerationModule(nn.Module): + config: LongT5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype) + + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel): + module_class = FlaxLongT5ForConditionalGenerationModule + + @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base") + >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/LongT5-Local-Base") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxLongT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jnp.DeviceArray] = None, + decoder_attention_mask: Optional[jnp.DeviceArray] = None, + encoder_outputs=None, + **kwargs + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base") + >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/LongT5-Local-Base") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py new file mode 100644 index 00000000000..292a54c64bb --- /dev/null +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -0,0 +1,2192 @@ +# coding=utf-8 +# Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LongT5 model.""" + + +import copy +import math +import warnings +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from .configuration_longt5 import LongT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LongT5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "google/LongT5-Local-Base" + +# TODO: Update before the merge +LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/LongT5-Local-Base", + "google/LongT5-Local-Large", + "google/LongT5-TGlobal-Base", + "google/LongT5-TGlobal-Large", +] + + +def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor: + """Pad a tensor so that a sequence length will be a multiple of `block_len`""" + pad_len = -x.shape[dim] % block_len + # Handle cases when an empty input sequence is given + if not all(x.shape): + new_shape = list(x.shape) + new_shape[dim] += pad_len + return torch.zeros(new_shape, dtype=x.dtype) + + pad = [(0, 0)] * x.ndim + pad[dim] = (0, pad_len) + pad = sum(pad[::-1], ()) + x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value) + return x + + +def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor: + """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length + is not a multiple of `block_len`, it will be padded first with selected `pad_value`. + """ + # pad tensor to multiple of block_len + if x.shape[dim] % block_len != 0: + x = _pad_to_multiple(x, block_len, dim, pad_value=0) + num_blocks = x.shape[dim] // block_len + output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :] + # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion + if 0 in output_shape: + return torch.empty(output_shape, dtype=x.dtype, device=x.device) + return x.reshape(output_shape) + + +def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor: + """Concatenate three consecutive blocks for each input block for local attentiont. + + For more information, see: https://arxiv.org/pdf/2112.07916.pdf. + """ + num_blocks = x.shape[block_dim] + + pad = [(0, 0)] * x.ndim + pad[block_dim] = (1, 1) + pad = sum(pad[::-1], ()) + # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len] + x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value) + + blocks_list: List[torch.Tensor] = [] + for i in range(3): + # We use indexing approach here: + # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs + indices = [slice(0, None)] * x.ndim + indices[block_dim] = slice(i, i + num_blocks) + indices = tuple(indices) + blocks_list.append(x[indices]) + # [batch_size, num_blocks, 3 * block_len, ...] + return torch.cat(blocks_list, dim=sequence_dim) + + +def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor: + """Makes 3-blocked relative position ids for local attention.""" + position_ids = torch.arange(3 * block_len, dtype=torch.int32) + center_position_ids = position_ids[block_len:-block_len] + # [block_len, 3 * block_len] + relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1) + return relative_position_ids + + +def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor: + """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius.""" + relative_position_ids = _make_3block_relative_position_ids(block_len) + locality_mask = torch.abs(relative_position_ids) < block_len + locality_mask = locality_mask[None, None, :, :] + locality_mask = locality_mask.to(local_attention_mask.device) + return torch.logical_and(local_attention_mask, locality_mask) + + +def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor: + """Prepare attention mask to be applied for a local attention.""" + # [batch_size, num_blocks, block_len] + _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1) + # [batch_size, num_block, 3 * block_len] + _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2) + + _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1) + _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2) + # [batch_size, num_block, block_len, 3 * block_len] + local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask) + local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len) + # [batch_size, 1, num_block, block_len, 3 * block_len] + return local_attention_mask.unsqueeze(1).to(device) + + +def _make_global_fixed_block_ids( + attention_mask: torch.Tensor, global_block_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Obtain the "fixed block" global id corresponding to each input token. + + This implementation is a simlified version of the original Flaxformr implementation adopted from: + https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py. + + In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for + the whole fixed block, are assigned to the preceding block. + + Padding tokens from the original sequence are represented by -1. + """ + batch_size, seq_len = attention_mask.shape[:2] + + def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor: + block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1 + block_ends = block_ends.to(block_ids.device) + true_block_ends = torch.logical_and(block_ends, block_ids >= 0) + full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1 + block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks) + return block_ids + + fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size + fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask + mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype) + global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype) + _global_block_ids_lower_bound = torch.tensor(-1.0, dtype=global_block_ids.dtype, device=global_block_ids.device) + global_block_ids = torch.where( + global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound + ) + # set padding tokens to -1 + global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1) + # [batch_size, seq_len] + global_block_ids = handle_orphan_tokens(global_block_ids) + num_globals = seq_len // global_block_size + # [batch_size, seq_len // global_block_size] + if num_globals > 0: + _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1) + else: + _sequence_block_ids_max = torch.zeros( + batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device + ) + global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1 + global_segment_ids = global_segment_ids.to(attention_mask.device) + global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0) + return global_block_ids.type(torch.int), global_segment_ids.type(torch.int) + + +def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor: + """Create the relative position tensor for local -> global attention.""" + block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size) + global_seq_len = global_segment_ids.shape[-1] + global_positions = torch.arange(global_seq_len, device=block_ids.device) + side_relative_position = global_positions - block_ids[..., None] + return side_relative_position.type(torch.int64) + + +def _create_global_aggregates( + hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int +) -> torch.Tensor: + """Compute individual block aggregates by summing over individual blocks.""" + # (batch..., seq_len, global_seq_len)) + block_ids = block_ids.where( + block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device) + ) + one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1] + return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype)) + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5 +class LongT5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + LongT5LayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm") +except ImportError: + # using the normal LongT5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm") + pass + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5 +class LongT5DenseActDense(nn.Module): + def __init__(self, config: LongT5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->LongT5 +class LongT5DenseGatedActDense(nn.Module): + def __init__(self, config: LongT5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5 +class LongT5LayerFF(nn.Module): + def __init__(self, config: LongT5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = LongT5DenseGatedActDense(config) + else: + self.DenseReluDense = LongT5DenseActDense(config) + + self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 +class LongT5Attention(nn.Module): + def __init__(self, config: LongT5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class LongT5LocalAttention(nn.Module): + def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None: + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.local_radius = config.local_radius + self.block_len = self.local_radius + 1 + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, block_length: int): + """Compute binned relative position bias""" + memory_position = torch.arange( + 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device + ) + context_position = memory_position[block_length:-block_length] + + # (block_length, 3 * block_length) + relative_position = memory_position[None, :] - context_position[:, None] + relative_position_bucket = self._relative_position_bucket( + relative_position, # (block_length, 3 * block_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # (block_length, 3 * block_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # (1, 1, num_heads, block_length, 3 * block_length) + values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0) + return values + + def forward( + self, + hidden_states, + mask=None, + position_bias=None, + layer_head_mask=None, + output_attentions=False, + ): + batch_size, seq_length = hidden_states.shape[:2] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + + def unshape(states): + """reshape""" + return states.contiguous().view(batch_size, -1, self.inner_dim) + + # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head) + query_states = shape(self.q(hidden_states)) + key_states = shape(self.k(hidden_states)) + value_states = shape(self.v(hidden_states)) + + # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head) + query_states = _split_into_blocks(query_states, self.block_len, dim=1) + key_states = _split_into_blocks(key_states, self.block_len, dim=1) + value_states = _split_into_blocks(value_states, self.block_len, dim=1) + + # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) + key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2) + value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2) + + # Compute scores + scores = torch.einsum( + "...qhd,...khd->...hqk", query_states, key_states + ) # (batch_size, num_block, n_heads, block_len, 3 * block_len) + + if position_bias is None: + # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(self.block_len) + + if mask is not None: + # Replace masked positions with -1e10 (according to the original implementation) + mask = torch.where(mask > 0, 0.0, -1e10) + # We need to adjust position bias shape to be sum with mask + position_bias = position_bias + mask.transpose(1, 2) + + scores += position_bias + # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + attn_weights = attn_weights.type(value_states.dtype) + attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states)) + attn_output = attn_output[:, :seq_length, :] + attn_output = self.o(attn_output) + + present_key_value_state = None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class LongT5TransientGlobalAttention(nn.Module): + def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None: + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.local_radius = config.local_radius + self.block_len = self.local_radius + 1 + self.global_block_size = config.global_block_size + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + # Relativen attention bias & Layer norm for global attention + if self.has_relative_attention_bias: + self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, block_length: int): + """Compute binned relative position bias""" + memory_position = torch.arange( + 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device + ) + context_position = memory_position[block_length:-block_length] + + # (block_length, 3 * block_length) + relative_position = memory_position[None, :] - context_position[:, None] + relative_position_bucket = self._relative_position_bucket( + relative_position, # (block_length, 3 * block_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # (block_length, 3 * block_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # (1, 1, num_heads, block_length, 3 * block_length) + values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0) + return values + + def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor: + # (batch_size, 1, seq_len, global_seq_len) + side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...] + attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10) + # (batch_size, seq_len, global_seq_len) + side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size) + side_relative_position_bucket = self._relative_position_bucket( + side_relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # (batch_size, seq_len, global_seq_len, num_heads) + side_bias = self.global_relative_attention_bias(side_relative_position_bucket) + + # (batch_size, num_heads, seq_len, global_seq_len) + side_bias = side_bias.permute([0, 3, 1, 2]) + # (batch_size, num_heads, seq_len, global_seq_len) + attention_side_bias = attention_side_bias + side_bias + return attention_side_bias + + def forward( + self, + hidden_states, + mask=None, + position_bias=None, + layer_head_mask=None, + output_attentions=False, + ): + batch_size, seq_length = hidden_states.shape[:2] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + + def unshape(states): + """reshape""" + return states.contiguous().view(batch_size, -1, self.inner_dim) + + # Prepare components for transient-global attention + # Obtain block_ids and global_segment_ids + # global_seq_len := seq_len // self.global_block_size + # shapes: (batch_size, seq_len) & (batch_size, global_seq_len) + block_ids, global_segment_ids = _make_global_fixed_block_ids( + mask if mask is not None else torch.ones(hidden_states.shape[:-1]), + self.global_block_size, + ) + # Create global inputs + _global_seq_len = global_segment_ids.shape[-1] + global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len) + global_inputs = self.global_input_layer_norm(global_inputs) + + # get query states -> (batch_size, seq_length, n_heads, dim_per_head) + query_states = shape(self.q(hidden_states)) + key_states = shape(self.k(hidden_states)) + value_states = shape(self.v(hidden_states)) + # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head) + side_key_states = shape(self.k(global_inputs)) + side_value_states = shape(self.v(global_inputs)) + + # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head) + query_states = _split_into_blocks(query_states, self.block_len, dim=1) + key_states = _split_into_blocks(key_states, self.block_len, dim=1) + value_states = _split_into_blocks(value_states, self.block_len, dim=1) + + # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) + key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2) + value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2) + + # Tile side inputs across local key/value blocks + # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head) + reps = [1] * (side_key_states.ndim + 1) + reps[1] = key_states.shape[1] + side_key_states = side_key_states.unsqueeze(1).repeat(reps) + side_value_states = side_value_states.unsqueeze(1).repeat(reps) + + # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones + # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head) + key_states = torch.cat([key_states, side_key_states], dim=2) + value_states = torch.cat([value_states, side_value_states], dim=2) + + # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len) + scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states) + + if mask is not None: + # We need to adjust position bias shape to be sum with mask + local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device) + # Replace masked positions with -10_000 (according to the original implementation) + local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10) + else: + local_attention_mask = None + + if position_bias is None: + # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, 1, self.n_heads, self.block_len, 3 * self.block_len), + device=scores.device, + dtype=scores.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(self.block_len) + + if local_attention_mask is not None: + # (batch_size, 1, n_heads, block_len, 3 * block_len) + position_bias = position_bias + local_attention_mask.transpose(1, 2) + position_bias = position_bias.type(scores.dtype) + + # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len) + if mask is None: + mask = torch.ones(batch_size, seq_length) + # (batch_size, num_heads, seq_len, global_seq_len) + side_position_bias = self.compute_side_bias(mask, global_segment_ids) + # (batch_size, num_blocks, num_heads, block_len, global_seq_len) + side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2) + side_position_bias = side_position_bias.type(scores.dtype).to(scores.device) + # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len) + position_bias = torch.cat([position_bias, side_position_bias], dim=-1) + + scores += position_bias + # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + attn_weights = attn_weights.type(value_states.dtype) + attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states)) + attn_output = attn_output[:, :seq_length, :] + attn_output = self.o(attn_output) + + present_key_value_state = None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 +class LongT5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class LongT5LayerLocalSelfAttention(nn.Module): + """Local self attention used in encoder""" + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + output_attentions=False, + **kwargs: Any, # to accept past_key_value and use_cache kwargs + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.LocalSelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class LongT5LayerTransientGlobalSelfAttention(nn.Module): + """Transient-Global self attention used in encoder""" + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + output_attentions=False, + **kwargs: Any, # to accept past_key_value and use_cache kwargs + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.TransientGlobalSelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 +class LongT5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) + self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class LongT5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + if config.is_decoder: + attention_layer = LongT5LayerSelfAttention + elif config.encoder_attention_type == "local": + attention_layer = LongT5LayerLocalSelfAttention + elif config.encoder_attention_type == "transient-global": + attention_layer = LongT5LayerTransientGlobalSelfAttention + else: + raise ValueError( + "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, " + f"but got {config.encoder_attention_type}." + ) + self.layer = nn.ModuleList() + self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(LongT5LayerCrossAttention(config)) + + self.layer.append(LongT5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class LongT5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LongT5Config + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + + @property + # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, LongT5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, LongT5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, LongT5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + if isinstance(module, LongT5TransientGlobalAttention): + module.global_relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + + # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LongT5Attention, LongT5Stack)): + module.gradient_checkpointing = value + + # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the" + " pad_token_id. See LongT5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class LongT5Stack(LongT5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.local_radius = config.local_radius + self.block_len = self.local_radius + 1 + + self.block = nn.ModuleList( + [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + self.gradient_checkpointing = False + + # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used + if self.is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, inputs_embeds.device + ) + elif self.config.encoder_attention_type == "local": + extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) + else: # we need to use both local attention mask and standard extended mask for transient-global attention + extended_attention_mask = attention_mask + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +LONGT5_START_DOCSTRING = r""" + + The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long + Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo + Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising + generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different + efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LongT5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +LONGT5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 + Training](./longt5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5 + Training](./longt5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +LONGT5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 + Training](./longt5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.", + LONGT5_START_DOCSTRING, +) +class LongT5Model(LongT5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: LongT5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = LongT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = LongT5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, LongT5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("google/LongT5-Local-Base") + >>> model = LongT5Model.from_pretrained("google/LongT5-Local-Base") + + >>> # Let's try a very long encoder input. + >>> input_ids = tokenizer( + ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) +class LongT5ForConditionalGeneration(LongT5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: LongT5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = LongT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = LongT5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps") + >>> model = LongT5ForConditionalGeneration.from_pretrained( + ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps" + ... ) + + >>> # Let's try a very long input. + >>> input_ids = tokenizer( + ... "summarize: " + 100 * "studies have shown that owning a dog is good for you ", return_tensors="pt" + ... ).input_ids # Batch size 1 + + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + abstractthe aim of this article is to summarize the studies have shown that owning a dog + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare LONGT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + LONGT5_START_DOCSTRING, +) +class LongT5EncoderModel(LongT5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: LongT5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = LongT5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google/LongT5-Local-Base") + >>> model = LongT5EncoderModel.from_pretrained("google/LongT5-Local-Base") + >>> input_ids = tokenizer( + ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index a54bb84712c..66dc321d26d 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -281,6 +281,13 @@ class FeaturesManager: "token-classification", onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig", ), + "longt5": supported_features_mapping( + "default", + "default-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + onnx_config_cls="models.longt5.LongT5OnnxConfig", + ), "marian": supported_features_mapping( "default", "default-with-past", diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index bc0d43b01b9..44c3b1cf3e4 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -725,6 +725,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLongT5Model(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLongT5PreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxMarianModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2ba91c10ab3..1c3df1e4f91 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2605,6 +2605,37 @@ class LongformerSelfAttention(metaclass=DummyObject): requires_backends(self, ["torch"]) +LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class LongT5EncoderModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LongT5ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LongT5Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LongT5PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/test_results.txt b/test_results.txt new file mode 100644 index 00000000000..b061a2f1b98 --- /dev/null +++ b/test_results.txt @@ -0,0 +1,9 @@ +background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred + + +background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred + + +background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred + + diff --git a/tests/models/longt5/__init__.py b/tests/models/longt5/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/longt5/test_modeling_flax_longt5.py b/tests/models/longt5/test_modeling_flax_longt5.py new file mode 100644 index 00000000000..9406e292d17 --- /dev/null +++ b/tests/models/longt5/test_modeling_flax_longt5.py @@ -0,0 +1,757 @@ +# coding=utf-8 +# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import numpy as np + +import transformers +from transformers import is_flax_available +from transformers.models.auto import get_values +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_flax, + require_sentencepiece, + require_tokenizers, + slow, +) + +from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import os + + # The slow tests are often failing with OOM error on GPU + # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed + # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + + import jax + import jax.numpy as jnp + from flax.core.frozen_dict import unfreeze + from flax.traverse_util import flatten_dict + from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.models.longt5.modeling_flax_longt5 import ( + FlaxLongT5ForConditionalGeneration, + FlaxLongT5Model, + shift_tokens_right, + ) + + +class FlaxLongT5ModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + local_radius=5, + encoder_attention_type="local", + global_block_size=3, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + self.local_radius = local_radius + self.block_len = local_radius + 1 + self.encoder_attention_type = encoder_attention_type + self.global_block_size = global_block_size + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + config = LongT5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + local_radius=self.local_radius, + encoder_attention_type=self.encoder_attention_type, + global_block_size=self.global_block_size, + ) + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ): + model = FlaxLongT5Model(config=config) + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size)) + + def check_use_cache_forward_with_attn_mask( + self, + model_class_name, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(input_ids) + + # prevent fully zero'd out attention mask + decoder_attention_mask = jnp.ones_like(decoder_attention_mask) + + decoder_attention_mask_cache = jnp.concatenate( + [ + decoder_attention_mask, + jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask_cache, + past_key_values=past_key_values, + ) + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + past_key_values=outputs_cache.past_key_values, + decoder_attention_mask=decoder_attention_mask_cache, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + return config, inputs_dict + + +@require_flax +class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): + + all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else () + all_generative_model_classes = (FlaxLongT5ForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + + def setUp(self): + self.model_tester = FlaxLongT5ModelTester(self) + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + + def test_use_cache_forward_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model.encode(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_decode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + model = model_class(config) + encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + + prepared_inputs_dict = { + "decoder_input_ids": inputs_dict["decoder_input_ids"], + "decoder_attention_mask": inputs_dict["decoder_attention_mask"], + "encoder_outputs": encoder_outputs, + } + + @jax.jit + def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs): + return model.decode( + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + with self.subTest("JIT Enabled"): + jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_shift_right(self): + decoder_start_token_id = 0 + pad_token_id = 1 + labels = np.arange(2, 102).reshape(5, 20) + labels[:2, 15:] = -100 + + decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id) + np_decoder_input_ids = np.array(decoder_input_ids) + + padded_slice = np_decoder_input_ids[:2, (15 + 1) :] + self.assertTrue((padded_slice == 1).all()) + + not_padded_slice = np_decoder_input_ids[2:, 1:] + rolled_labels = np.roll(labels[2:], 1)[:, 1:] + self.assertTrue((not_padded_slice == rolled_labels).all()) + self.assertTrue((np_decoder_input_ids[:, 0] == 0).all()) + + # overwrite since special base model prefix is used + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_length = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + block_len = getattr(self.model_tester, "block_len", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # Question Answering model returns start_logits and end_logits + if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + +class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest): + def setUp(self): + self.model_tester = FlaxLongT5ModelTester(self, encoder_attention_type="transient-global") + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_length = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + block_len = getattr(self.model_tester, "block_len", None) + global_block_size = getattr(self.model_tester, "global_block_size", None) + global_seq_len = encoder_seq_length // global_block_size + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # Question Answering model returns start_logits and end_logits + if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + + +@require_sentencepiece +@require_tokenizers +@require_flax +class FlaxLongT5ModelIntegrationTests(unittest.TestCase): + model_path = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps" + + def expected_summary(self): + return [ + "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in" + " developing world . it provides an excellent resolution for visualization of the coronary arteries for" + " catheter - based or operating interventions . although the association of this technique with major" + " complications such as mortality is highly uncommon , it is frequently associated with various cardiac" + " and noncardiac complications . computed tomography coronary angiography is a promising technique for the" + " evaluation of cad noninvasively . it assesses disease within the coronary artery and provides" + " qualitative and quantitative information about nonobstructive atherosclerotic plaque" + ] + + @slow + def test_summarization(self): + model = FlaxLongT5ForConditionalGeneration.from_pretrained(self.model_path) + tok = AutoTokenizer.from_pretrained(self.model_path) + + ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n + although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is + a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel + wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12 + 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s + ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women + aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p + roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron + ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche + duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse + nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati + ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w + ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog + y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a + re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque + ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m + aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re + gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s + ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp + hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean + lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp + hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th + e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were + used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di + agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and + positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p + er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary + lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal + irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an + giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated + software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi + ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction + s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif + icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of + confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n + in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement , + double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu + dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct + coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed + iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient + who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos + tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and + 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these + patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy + or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc . + in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i + nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p + retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a + lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang + iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from + a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography + could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat + ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient + s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja + rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """ + + dct = tok( + [ARTICLE], + max_length=1024, + padding="max_length", + truncation=True, + return_tensors="np", + ) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + do_sample=False, + early_stopping=True, + ).sequences + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + self.expected_summary(), + decoded, + ) diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py new file mode 100644 index 00000000000..e6716b7fbcc --- /dev/null +++ b/tests/models/longt5/test_modeling_longt5.py @@ -0,0 +1,1315 @@ +# coding=utf-8 +# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import tempfile +import unittest + +from transformers import LongT5Config, is_torch_available +from transformers.models.auto import get_values +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.utils import cached_property + +from ...generation.test_generation_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + AutoTokenizer, + LongT5EncoderModel, + LongT5ForConditionalGeneration, + LongT5Model, + ) + from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST + + +class LongT5ModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + local_radius=5, + encoder_attention_type="local", + global_block_size=3, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + large_model_config_path="google/LongT5-Local-Large", + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + self.local_radius = local_radius + self.block_len = local_radius + 1 + self.encoder_attention_type = encoder_attention_type + self.global_block_size = global_block_size + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + self.large_model_config_path = large_model_config_path + + def get_large_model_config(self): + return LongT5Config.from_pretrained(self.large_model_config_path) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def get_pipeline_config(self): + return LongT5Config( + vocab_size=166, # longt5 forces 100 extra tokens + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + local_radius=self.local_radius, + encoder_attention_type=self.encoder_attention_type, + global_block_size=self.global_block_size, + ) + + def get_config(self): + return LongT5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + local_radius=self.local_radius, + encoder_attention_type=self.encoder_attention_type, + global_block_size=self.global_block_size, + ) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5Model(config=config) + model.to(torch_device) + model.eval() + + # make sure that lm_labels are correctly padded from the right + lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) + + # add casaul pad token mask + triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() + lm_labels.masked_fill_(triangular_mask, self.pad_token_id) + decoder_input_ids = model._shift_right(lm_labels) + + for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): + # first item + self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) + if i < decoder_input_ids_slice.shape[-1]: + if i < decoder_input_ids.shape[-1] - 1: + # items before diagonal + self.parent.assertListEqual( + decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() + ) + # pad items after diagonal + if i < decoder_input_ids.shape[-1] - 2: + self.parent.assertListEqual( + decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() + ) + else: + # all items after square + self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5Model(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5Model(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5Model(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5Model(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_encoder_decoder_shared_weights( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + for model_class in [LongT5Model, LongT5ForConditionalGeneration]: + torch.manual_seed(0) + model = model_class(config=config).to(torch_device).eval() + # load state dict copies weights but does not tie them + model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) + + torch.manual_seed(0) + tied_config = copy.deepcopy(config) + tied_config.tie_encoder_decoder = True + tied_model = model_class(config=tied_config).to(torch_device).eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = model_class.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], + tied_model_result[0][0, :, random_slice_idx], + atol=1e-4, + ) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "use_cache": False, + } + return config, inputs_dict + + +@require_torch +class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + + all_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else () + fx_compatible = False + all_parallelizable_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else () + test_pruning = False + test_torchscript = True + test_resize_embeddings = True + test_model_parallel = True + is_encoder_decoder = True + + def setUp(self): + self.model_tester = LongT5ModelTester(self) + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + def test_encoder_decoder_shared_weights(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = LongT5Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_export_to_onnx(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + model = LongT5Model(config_and_inputs[0]).to(torch_device) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.onnx.export( + model, + (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), + f"{tmpdirname}/longt5_test.onnx", + export_params=True, + opset_version=13, + input_names=["input_ids", "decoder_input_ids"], + ) + + def test_generate_with_head_masking(self): + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + max_length = config_and_inputs[1].shape[-1] + 3 + model = LongT5ForConditionalGeneration(config).eval() + model.to(torch_device) + + head_masking = { + "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), + "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + } + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + head_masks = {name: mask} + # Explicitly pass decoder_head_mask as it is required from LONGT5 model when head_mask specified + if name == "head_mask": + head_masks["decoder_head_mask"] = torch.ones( + config.num_decoder_layers, config.num_heads, device=torch_device + ) + + out = model.generate( + config_and_inputs[1], + num_beams=1, + max_length=max_length, + output_attentions=True, + return_dict_in_generate=True, + **head_masks, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + + def test_attention_outputs(self): + if not self.has_attentions: + pass + + else: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + block_len = getattr(self.model_tester, "block_len", None) + + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Question Answering model returns start_logits and end_logits + if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + block_len = getattr(self.model_tester, "block_len", None) + encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len) + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [layer_attentions.shape for layer_attentions in attentions], + [encoder_expected_shape] * len(attentions), + ) + + +@require_torch +class LongT5TGlobalModelTest(LongT5ModelTest): + def setUp(self): + self.model_tester = LongT5ModelTester( + self, encoder_attention_type="transient-global", large_model_config_path="google/LongT5-TGlobal-Large" + ) + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_attention_outputs(self): + if not self.has_attentions: + pass + + else: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + block_len = getattr(self.model_tester, "block_len", None) + global_block_size = getattr(self.model_tester, "global_block_size", None) + global_seq_len = encoder_seq_length // global_block_size + + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Question Answering model returns start_logits and end_logits + if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + block_len = getattr(self.model_tester, "block_len", None) + global_block_size = getattr(self.model_tester, "global_block_size", None) + global_seq_length = seq_length // global_block_size + encoder_expected_shape = ( + batch_size, + 1, + config.num_attention_heads, + block_len, + 3 * block_len + global_seq_length, + ) + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [layer_attentions.shape for layer_attentions in attentions], + [encoder_expected_shape] * len(attentions), + ) + + +class LongT5EncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + local_radius=5, + encoder_attention_type="local", + global_block_size=3, + # For common tests + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + is_training=False, + dropout_rate=0.1, + initializer_factor=0.002, + is_encoder_decoder=False, + eos_token_id=1, + pad_token_id=0, + scope=None, + large_model_config_path="google/LongT5-Local-Large", + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.local_radius = local_radius + self.block_len = local_radius + 1 + self.encoder_attention_type = encoder_attention_type + self.global_block_size = global_block_size + # For common tests + self.seq_length = self.encoder_seq_length + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.is_encoder_decoder = is_encoder_decoder + self.scope = None + self.is_training = is_training + self.large_model_config_path = large_model_config_path + + def get_large_model_config(self): + return LongT5Config.from_pretrained(self.large_model_config_path) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = LongT5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + local_radius=self.local_radius, + encoder_attention_type=self.encoder_attention_type, + global_block_size=self.global_block_size, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = LongT5EncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (LongT5EncoderModel,) if is_torch_available() else () + test_pruning = False + test_torchscript = True + test_resize_embeddings = False + test_model_parallel = True + all_parallelizable_model_classes = (LongT5EncoderModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = LongT5EncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_attention_outputs(self): + if not self.has_attentions: + pass + + else: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + block_len = getattr(self.model_tester, "block_len", 4) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len], + ) + + +class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest): + def setUp(self): + self.model_tester = LongT5EncoderOnlyModelTester( + self, encoder_attention_type="transient-global", large_model_config_path="google/LongT5-TGlobal-Large" + ) + self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37) + + def test_attention_outputs(self): + if not self.has_attentions: + pass + + else: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + block_len = getattr(self.model_tester, "block_len", None) + seq_len = getattr(self.model_tester, "seq_length", None) + global_block_size = getattr(self.model_tester, "global_block_size", 4) + global_seq_len = seq_len // global_block_size + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], + ) + + +def use_task_specific_params(model, task): + model.config.update(model.config.task_specific_params[task]) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class LongT5ModelIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps").to( + torch_device + ) + + @cached_property + def tokenizer(self): + return AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps") + + def expected_summary(self): + return [ + "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in" + " developing world . it provides an excellent resolution for visualization of the coronaryarteries for" + " catheter - based or operating interventions . although the association of this technique with major" + " complications such as mortality is highly uncommon , it is frequently associated with various cardiac" + " and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the" + " diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for" + " major noncoron ary cardiac surgery referred" + ] + + @slow + def test_summarization(self): + model = self.model + tok = self.tokenizer + + ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n + although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is + a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel + wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12 + 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s + ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women + aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p + roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron + ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche + duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse + nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati + ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w + ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog + y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a + re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque + ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m + aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re + gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s + ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp + hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean + lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp + hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th + e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were + used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di + agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and + positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p + er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary + lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal + irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an + giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated + software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi + ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction + s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif + icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of + confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n + in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement , + double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu + dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct + coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed + iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient + who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos + tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and + 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these + patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy + or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc . + in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i + nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p + retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a + lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang + iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from + a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography + could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat + ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient + s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja + rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """ + + dct = tok( + [ARTICLE], + max_length=1024, + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(torch_device) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + self.expected_summary(), + decoded, + ) + + @slow + def test_inference_hidden_states(self): + model = self.model + + input_ids = torch.tensor( + [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.long, + device=torch_device, + ) + decoder_input_ids = torch.tensor( + [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.long, + device=torch_device, + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.long, + device=torch_device, + ) + + output = model( + input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, output_hidden_states=True + ) + + # check if encoder_outputs match + expected_output_slice = torch.tensor([0.0629, -0.1294, -0.0089, 0.0772, 0.0663], device=torch_device) + self.assertTrue(torch.allclose(output.encoder_hidden_states[-1][0, 0, :5], expected_output_slice, atol=1e-4)) + + # check if logits match + expected_output_slice = torch.tensor([5.5231, 6.1058, 3.1766, 8.2391, -5.9453], device=torch_device) + self.assertTrue(torch.allclose(output.logits[0, 0, :5], expected_output_slice, atol=1e-4)) diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 83ba77b4917..ee7fba96597 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -213,6 +213,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ("blenderbot-small", "facebook/blenderbot_small-90M"), ("blenderbot", "facebook/blenderbot-400M-distill"), ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"), + ("longt5", "google/LongT5-Local-Base"), + ("longt5", "google/LongT5-TGlobal-Base"), } # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations. diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index f802e5b63d1..d797383811c 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -18,6 +18,7 @@ from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, LEDConfig, + LongT5Config, SummarizationPipeline, T5Config, pipeline, @@ -54,8 +55,8 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe ) self.assertEqual(outputs, [{"summary_text": ANY(str)}]) - if not isinstance(model.config, (T5Config, LEDConfig)): - # LED, T5 can handle it. + if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)): + # LED, T5, LongT5 can handle it. # Too long. with self.assertRaises(Exception): outputs = summarizer("This " * 1000) diff --git a/utils/check_repo.py b/utils/check_repo.py index f17a425f19a..c3060b048ae 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -36,6 +36,7 @@ PATH_TO_DOC = "docs/source/en" # Update this list with models that are supposed to be private. PRIVATE_MODELS = [ "DPRSpanPredictor", + "LongT5Stack", "RealmBertModel", "T5Stack", "TFDPRSpanPredictor", diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 1738d2e92df..462c75bc5c5 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -36,6 +36,7 @@ src/transformers/models/layoutlmv2/modeling_layoutlmv2.py src/transformers/models/layoutlmv3/modeling_layoutlmv3.py src/transformers/models/longformer/modeling_longformer.py src/transformers/models/longformer/modeling_tf_longformer.py +src/transformers/models/longt5/modeling_longt5.py src/transformers/models/marian/modeling_marian.py src/transformers/models/mbart/modeling_mbart.py src/transformers/models/mobilebert/modeling_mobilebert.py