From 860264379f2c1b86f7f0fd5bbd5cb63619b773e0 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 30 Mar 2021 19:12:30 +0530 Subject: [PATCH] GPT Neo (#10848) * lets begin * boom boom * fix out proj in attn * fix attention * fix local attention * add tokenizer * fix imports * autotokenizer * fix checkpoint name * cleanup * more clean-up * more cleanup * output attentions * fix attn mask creation * fix imports * config doc * add tests * add slow tests * quality * add conversion script * copyright * typo * another bites the dust * fix attention tests * doc * add embed init in convert function * fix copies * remove tokenizer * enable caching * address review comments * improve config and create attn layer list internally * more consistent naming * init hf config from mesh-tf config json file * remove neo tokenizer from doc * handle attention_mask in local attn layer * attn_layers => attention_layers * add tokenizer_class in config * fix docstring * raise if len of attention_layers is not same as num_layers * remove tokenizer_class from config * more consistent naming * fix doc * fix checkpoint names * fp16 compat * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- README.md | 1 + docs/source/index.rst | 57 +- docs/source/model_doc/gpt_neo.rst | 65 ++ docs/source/pretrained_models.rst | 6 + src/transformers/__init__.py | 19 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 4 + src/transformers/models/auto/modeling_auto.py | 8 +- src/transformers/models/gpt_neo/__init__.py | 70 ++ .../models/gpt_neo/configuration_gpt_neo.py | 175 ++++ .../convert_gpt_neo_mesh_tf_to_pytorch.py | 70 ++ .../models/gpt_neo/modeling_gpt_neo.py | 964 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 30 + tests/test_modeling_gpt_neo.py | 511 ++++++++++ 14 files changed, 1953 insertions(+), 28 deletions(-) create mode 100644 docs/source/model_doc/gpt_neo.rst create mode 100644 src/transformers/models/gpt_neo/__init__.py create mode 100644 src/transformers/models/gpt_neo/configuration_gpt_neo.py create mode 100644 src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py create mode 100755 src/transformers/models/gpt_neo/modeling_gpt_neo.py create mode 100644 tests/test_modeling_gpt_neo.py diff --git a/README.md b/README.md index 30a00c8c277..a643fe82530 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[Funnel Transformer](https://huggingface.co/transformers/model_doc/funnel.html)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. +1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer 1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. diff --git a/docs/source/index.rst b/docs/source/index.rst index 373012c99c0..03652a77cae 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -151,79 +151,81 @@ and conversion utilities for the following models: 22. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -23. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +23. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo + `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +24. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -24. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +25. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -25. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +26. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -26. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +27. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -27. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +28. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -28. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +29. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -29. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +30. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -30. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +31. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -31. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +32. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -32. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +33. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -33. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +34. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -34. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +35. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -35. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +36. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -36. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +37. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -37. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +38. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -38. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +39. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -39. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +40. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -40. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +41. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -41. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +42. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -42. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +43. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -43. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +44. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -44. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +45. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -45. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +46. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -46. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +47. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -47. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +48. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -48. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +49. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -280,6 +282,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| GPT Neo | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | LED | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -443,6 +447,7 @@ TensorFlow and/or Flax. model_doc/mt5 model_doc/gpt model_doc/gpt2 + model_doc/gpt_neo model_doc/pegasus model_doc/phobert model_doc/prophetnet diff --git a/docs/source/model_doc/gpt_neo.rst b/docs/source/model_doc/gpt_neo.rst new file mode 100644 index 00000000000..e7a3732913b --- /dev/null +++ b/docs/source/model_doc/gpt_neo.rst @@ -0,0 +1,65 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +GPT Neo +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The GPTNeo model was released in the `EleutherAI/gpt-neo `__ repository by Sid +Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. It is a GPT2 like causal language model trained on the +`Pile `__ dataset. + +The architecture is similar to GPT2 except that GPT Neo uses local attention in every other layer with a window size of +256 tokens. + +Generation +_______________________________________________________________________________________________________________________ + +The :obj:`generate()` method can be used to generate text using GPT Neo model. + +.. code-block:: + + >>> from transformers import GPTNeoForCausalLM, GPT2Tokenizer + >>> model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl") + >>> tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt_neo_xl") + + >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ + ... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ + ... "researchers was the fact that the unicorns spoke perfect English." + + >>> input_ids = tokenizer(unicorns, return_tensors="pt").input_ids + + >>> gen_tokens = model.generate(ids, do_sample=True, temperature=0.9, max_length=100,) + >>> gen_text = tokenizer.batch_decode(gen_tokens)[0] + + +GPTNeoConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTNeoConfig + :members: + + +GPTNeoModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTNeoModel + :members: forward + + +GPTNeoForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTNeoForCausalLM + :members: forward diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 4a29ebf4eea..f8bcef05867 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -139,6 +139,12 @@ For the full list, refer to `https://huggingface.co/models `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.GPTNeoModel`. Vocabulary size of the model. + Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of + :class:`~transformers.GPTNeoModel`. + attention_types (:obj:`List`, `optional`, defaults to :obj:`[[["global", "local"], 12]]`): + The type of attention for each layer in a :obj:`List` of the following format :obj:`[[["attention_type"], + num_layerss]]` e.g. for a 24 layer model :obj:`[[["global"], 24]]` or :obj:`[[["global", "local"], 12]]` + Choose the value of ``attention_type`` from :obj:`["global", "local"]` + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.GPTNeoModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import GPTNeoModel, GPTNeoConfig + + >>> # Initializing a GPTNeo EleutherAI/gpt_neo_xl style configuration + >>> configuration = GPTNeoConfig() + + >>> # Initializing a model from the EleutherAI/gpt_neo_xl style configuration + >>> model = GPTNeoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "gpt_neo" + + def __init__( + self, + vocab_size=50257, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=24, + attention_types=[[["global", "local"], 12]], + num_heads=16, + intermediate_size=None, + window_size=256, + activation_function="gelu_new", + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + gradient_checkpointing=False, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_size = intermediate_size + self.window_size = window_size + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.gradient_checkpointing = gradient_checkpointing + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.attention_types = attention_types + self.attention_layers = self.expand_attention_types_params(attention_types) + + if len(self.attention_layers) != self.num_layers: + raise ValueError( + "Configuration for convolutional module is incorrect." + "It is required that `len(config.attention_layers)` == `config.num_layers`" + f"but is `len(config.attention_layers) = {len(self.attention_layers)}`," + f"`config.num_layers = {self.num_layers}`." + "`config.attention_layers` is prepared using `config.attention_types`." + "Please verify the value of `config.attention_types` argument." + ) + + @staticmethod + def expand_attention_types_params(attention_types): + attentions = [] + for item in attention_types: + for _ in range(item[1]): + attentions.extend(item[0]) + return attentions + + @property + def num_attention_heads(self): + return self.num_heads + + @property + def num_hidden_layers(self): + return self.num_layers diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py new file mode 100644 index 00000000000..8378ad53697 --- /dev/null +++ b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert GPT Neo checkpoint.""" + + +import argparse +import json + +from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config_json = json.load(open(config_file, "r")) + config = GPTNeoConfig( + hidden_size=config_json["n_embd"], + num_layers=config_json["n_layer"], + num_heads=config_json["n_head"], + attention_types=config_json["attention_types"], + max_position_embeddings=config_json["n_ctx"], + resid_dropout=config_json["res_dropout"], + embed_dropout=config_json["embed_dropout"], + attention_dropout=config_json["attn_dropout"], + ) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = GPTNeoForCausalLM(config) + + # Load weights from tf checkpoint + load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path) + + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained mesh-tf model. \n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py new file mode 100755 index 00000000000..8903e41d25f --- /dev/null +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -0,0 +1,964 @@ +# coding=utf-8 +# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch GPT Neo model. """ + + +import os +from typing import Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_gpt_neo import GPTNeoConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GPTNeoConfig" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "EleutherAI/gpt_neo_xl", + # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo +] + +_CHECKPOINT_FOR_DOC = "EleutherAI/gpt_neo_xl" + + +def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt_neo_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + if "global_step" not in name and "adam" not in name: + array = tf.train.load_variable(tf_path, name) + array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy() + name = name.replace("attn/q", "attn/attention/q_proj/w") + name = name.replace("attn/k", "attn/attention/k_proj/w") + name = name.replace("attn/v", "attn/attention/v_proj/w") + name = name.replace("attn/o", "attn/attention/out_proj/w") + name = name.replace("norm_1", "ln_1") + name = name.replace("norm_2", "ln_2") + name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b") + name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w") + name = name.replace("conv1d_main/c_fc/bias", "c_fc/b") + name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w") + name = name.replace("conv1d_main/c_proj/bias", "c_proj/b") + + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name[5:] # skip "gpt2/" + name = name.split("/") + pointer = model.transformer + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: + array = array.transpose() + + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + + # init the final linear layer using word embeddings + embs = model.transformer.wte.weight + lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False) + lin.weight = embs + model.set_output_embeddings(lin) + return model + + +class GPTNeoSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attention_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): + # Keep the attention weights computation in fp32 to avoid overflow issues + q = q.to(torch.float32) + k = k.to(torch.float32) + + attn_weights = torch.matmul(q, k) + nd, ns = attn_weights.size(-2), attn_weights.size(-1) + + mask = self.bias[:, :, ns - nd : ns, :ns] + attn_weights = torch.where(mask.bool(), attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(v.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + outputs = (torch.matmul(attn_weights, v),) + if output_attentions: + outputs += (attn_weights,) + return outputs + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads) + x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states + if k: + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) + else: + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + key = torch.cat((past_key, key), dim=-1) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key.transpose(-2, -1), value) # transpose to have same shapes + else: + present = None + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.out_proj(a) + a = self.resid_dropout(a) + + return (a, present) + attn_outputs[1:] # a, present, (attentions) + + +class GPTNeoLocalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attention_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + self.window_size = config.window_size + + def shift(self, x, offset, pad_value=0, dim=2): + t = x.shape[1] + dims = (len(x.shape) - dim) * (0, 0) + padded_x = F.pad(x, (*dims, offset, 0), value=pad_value) + return padded_x[:, :t, ...] + + def look_around(self, x, block_length, window_size): + num_complete_blocks = window_size // block_length + + parts = [x] + for i in range(1, num_complete_blocks + 1): + parts = [self.shift(x, i)] + parts + + partial_size = window_size % block_length + if partial_size > 0: + margin = x[:, :, block_length - partial_size : block_length, ...] + parts = [self.shift(margin, num_complete_blocks + 1)] + parts + return torch.cat(parts, dim=2) + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads) + x = x.view(*new_x_shape) + if k: + return x.permute(0, 1, 3, 4, 2) # (batch, chunks, head, head_features, seq_length) + else: + return x.permute(0, 1, 3, 2, 4) # (batch, chunks, head, seq_length, head_features) + + def merge_heads(self, x): + x = x.permute(0, 1, 3, 2, 4).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) + + def _split_seq_length_dim_to(self, tensors, num_blocks, block_length): + return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1) + + def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask): + ticker = torch.arange(seq_len)[None, :] + b_t = ticker.reshape(1, windows, block_length) + + bq_t = b_t + bq_k = self.look_around(b_t, block_length, self.window_size) + + # compute attn mask + # this matches the original implem in mess-tensorflow + # https://github.com/tensorflow/mesh/blob/8bd599a21bad01cef1300a8735c17306ce35db6e/mesh_tensorflow/transformer/attention.py#L805 + relative_position = bq_k.unsqueeze(-2) - bq_t.unsqueeze(-1) + relative_position = relative_position.transpose(-1, -2) + + sequence_id = torch.ones(bs, seq_len) + q_seq = sequence_id.reshape(-1, windows, block_length) + m_seq = sequence_id.reshape(-1, windows, block_length) + m_seq = self.look_around(m_seq, block_length, self.window_size) + + if attention_mask is not None: + attention_mask = attention_mask.to(m_seq.device) + attention_mask = attention_mask.reshape(-1, windows, block_length) + attention_mask = self.look_around(attention_mask, block_length, self.window_size) + m_seq *= attention_mask + + visible = torch.eq(q_seq.unsqueeze(-1), m_seq.unsqueeze(-2)).transpose(-1, -2) + visible = torch.logical_and(visible, torch.gt(relative_position, -self.window_size)) + mask = torch.logical_and(visible, torch.less_equal(relative_position, 0)).transpose(-1, -2).unsqueeze(2) + return mask + + def _attn(self, q, k, v, causal_mask, head_mask=None, output_attentions=False): + # attn + + # Keep the attention weights computation in fp32 to avoid overflow issues + q = q.to(torch.float32) + k = k.to(torch.float32) + + attn_weights = torch.matmul(q, k) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(v.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, v) + + outputs = (attn_output,) + if output_attentions: + outputs += (attn_weights,) + return outputs + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + query = self.q_proj(hidden_states) + + if layer_past is not None: + past = layer_past[0] + key_value_hidden_states = torch.cat([past, hidden_states], dim=1) + past_length = past.size()[1] + else: + key_value_hidden_states = hidden_states + past_length = 0 + + key = self.k_proj(key_value_hidden_states) + value = self.v_proj(key_value_hidden_states) + + # compute block length and windows + bs, seq_len = hidden_states.shape[:2] + full_seq_length = seq_len + past_length + block_length = self.window_size + while full_seq_length % block_length != 0: + block_length -= 1 + num_blocks = full_seq_length // block_length + + # create buckets + if layer_past is not None: + # we just need 1 window with block_length 1 when caching is enabled + query = self._split_seq_length_dim_to(query, 1, 1) + else: + query = self._split_seq_length_dim_to(query, num_blocks, block_length) + + key = self._split_seq_length_dim_to(key, num_blocks, block_length) + value = self._split_seq_length_dim_to(value, num_blocks, block_length) + + key = self.look_around(key, block_length, self.window_size) + value = self.look_around(value, block_length, self.window_size) + + # select key/value vectors only for the last window + if layer_past is not None: + key = key[:, -1:, ...] + value = value[:, -1:, ...] + + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + + mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask) + if layer_past is not None: + mask = mask[:, -1:, :, -1:, :] # only take the mask for the last window + mask = mask.to(hidden_states.device) + + # attn + attn_outputs = self._attn(query, key, value, mask, head_mask, output_attentions) + attn = attn_outputs[0] + + attn = self.merge_heads(attn) + attn = attn.reshape(bs, seq_len, self.embed_dim) + + attn = self.out_proj(attn) + attn = self.resid_dropout(attn) + return (attn,) + attn_outputs[1:] + + +class GPTNeoAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.layer_id = layer_id + self.attention_layers = config.attention_layers + self.attention_type = self.attention_layers[layer_id] + + if self.attention_type == "global": + self.attention = GPTNeoSelfAttention(config) + elif self.attention_type == "local": + self.attention = GPTNeoLocalSelfAttention(config) + else: + raise NotImplementedError( + "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {}. Select attn layer types from ['global', 'local'] only.".format( + self.attention_layers + ) + ) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + outputs = self.attention( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # cache the hidden_states instead of key_value_states + # for local attention layer + if self.attention_type == "local": + if layer_past is None: + past = hidden_states + else: + past = torch.cat([layer_past[0], hidden_states], dim=1) + outputs = (outputs[0], (past,)) + outputs[1:] + return outputs + + +class MLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_dropout) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class Block(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTNeoAttention(config, layer_id) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = MLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + attn_outputs = self.attn( + self.ln_1(hidden_states), + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + hidden_states + + feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) + # residual connection + hidden_states = hidden_states + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPTNeoPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTNeoConfig + load_tf_weights = load_tf_weights_in_gpt_neo + base_model_prefix = "transformer" + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +GPT_NEO_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.GPTNeoConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +GPT_NEO_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): + :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else + ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be + passed as ``input_ids``. + + Indices can be obtained using :class:`~transformers.GPTNeoTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.num_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which + have their past given to this model should not be passed as ``input_ids`` as they have already been + computed. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + + If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see + :obj:`past_key_values`). + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.", + GPT_NEO_START_DOCSTRING, +) +class GPTNeoModel(GPTNeoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.drop = nn.Dropout(config.embed_dropout) + self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.num_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + self.init_weights() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + global_attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + global_attention_mask = global_attention_mask[:, None, None, :] + + # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility + global_attention_mask = (1.0 - global_attention_mask) * -10000.0 + else: + global_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_headss x N x N + # head_mask has shape n_layer x batch x num_headss x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + attn_type = self.config.attention_layers[i] + attn_mask = global_attention_mask if attn_type == "global" else attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attn_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attn_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@add_start_docstrings( + """ + The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT_NEO_START_DOCSTRING, +) +class GPTNeoForCausalLM(GPTNeoPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + _keys_to_ignore_on_save = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTNeoModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index cf9109d3607..139d229a879 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1449,6 +1449,36 @@ def load_tf_weights_in_gpt2(*args, **kwargs): requires_pytorch(load_tf_weights_in_gpt2) +GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class GPTNeoForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class GPTNeoModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class GPTNeoPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +def load_tf_weights_in_gpt_neo(*args, **kwargs): + requires_pytorch(load_tf_weights_in_gpt_neo) + + IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py new file mode 100644 index 00000000000..bea0ee77645 --- /dev/null +++ b/tests/test_modeling_gpt_neo.py @@ -0,0 +1,511 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch GPT Neo model. """ + + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST, + GPT2Tokenizer, + GPTNeoConfig, + GPTNeoForCausalLM, + GPTNeoModel, + ) + + +class GPTNeoModelTester: + def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_token_type_ids=True, + use_input_mask=True, + use_labels=True, + use_mc_token_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=4, + attention_types=[[["global", "local"], 2]], + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + window_size=7, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.window_size = window_size + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + self.chunk_length = window_size + self.attention_types = attention_types + + def get_large_model_config(self): + return GPTNeoConfig.from_pretrained("gpt_neo") + + def prepare_config_and_inputs(self, gradient_checkpointing=False): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = GPTNeoConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + max_position_embeddings=self.max_position_embeddings, + use_cache=not gradient_checkpointing, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + window_size=self.window_size, + attention_types=self.attention_types, + ) + + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_gpt_neo_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTNeoModel(config=config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + # past_key_values is not implemented + # self.parent.assertEqual(len(result.past_key_values), config.n_layer) + + def create_and_check_gpt_neo_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTNeoModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids) + outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) + + output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTNeoForCausalLM(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTNeoForCausalLM(config) + model.to(torch_device) + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + result.loss.backward() + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "head_mask": head_mask, + } + + return config, inputs_dict + + +@require_torch +class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + + all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else () + all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () + test_missing_keys = False + test_pruning = False + test_model_parallel = False + + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + return inputs_dict + + def setUp(self): + self.model_tester = GPTNeoModelTester(self) + self.config_tester = ConfigTester(self, config_class=GPTNeoConfig, n_embd=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_gpt_neo_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_neo_model(*config_and_inputs) + + def test_gpt_neo_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_neo_model_past(*config_and_inputs) + + def test_gpt_neo_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + def test_gpt_neo_gradient_checkpointing(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + + def _get_local_attn_seq_len_block_len_windows(self, seq_len, window_size): + block_length = window_size + while seq_len % block_length != 0: + block_length -= 1 + windows = seq_len // block_length + local_seq_len = window_size + block_length + return local_seq_len, block_length, windows + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # test global attention shape + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, seq_len], + ) + # test local attention shape + encoder_key_length = self._get_local_attn_seq_len_block_len_windows(seq_len, chunk_length)[0] + self.assertListEqual( + list(attentions[-1].shape[-3:]), + [self.model_tester.num_attention_heads, seq_len, encoder_key_length], + ) + + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + + # test global attention shape + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, seq_len], + ) + + # test local attention shape + self.assertListEqual( + list(self_attentions[-1].shape[-3:]), + [self.model_tester.num_attention_heads, seq_len, encoder_key_length], + ) + + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + for idx, iter_attentions in enumerate(attentions): + tgt_len = min_length + idx if not use_cache else 1 + src_len = min_length + idx + global_expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + + local_seq_len, block_len, windows = self._get_local_attn_seq_len_block_len_windows( + src_len, config.window_size + ) + block_len = 1 if use_cache else block_len + local_expected_shape = ( + batch_size * num_beam_groups, + windows, + config.num_attention_heads, + block_len, + local_seq_len, + ) + + shapes = [layer_attention.shape for layer_attention in iter_attentions] + # every other layer is local attention layers + # so alternate between expected shapes + expected_shape = [ + global_expected_shape if i % 2 == 0 else local_expected_shape for i, _ in enumerate(iter_attentions) + ] + # check attn size + self.assertListEqual(shapes, expected_shape) + + @slow + def test_batch_generation(self): + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl") + model.to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + tokenizer.padding_side = "left" + + # Define PAD Token = EOS Token = 50256 + tokenizer.pad_token = tokenizer.eos_token + model.config.pad_token_id = model.config.eos_token_id + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I am", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little bit of a kitty. She is a very sweet and loving", + "Today, I am going to talk about the best way to get a job in the", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + + @slow + def test_model_from_pretrained(self): + for model_name in GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = GPTNeoModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class GPTNeoModelLanguageGenerationTest(unittest.TestCase): + @slow + def test_lm_generate_gpt_neo(self): + for checkpointing in [True, False]: + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl", gradient_checkpointing=checkpointing) + model.to(torch_device) + input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog + # fmt: off + expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11] # The dog-eared copy of the book, which is a collection of essays by the late author, + # fmt: on + output_ids = model.generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + + @slow + def test_gpt_neo_sample(self): + tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt_neo_xl") + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl") + model.to(torch_device) + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + output_ids = model.generate(input_ids, do_sample=True) + output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can" + self.assertEqual(output_str, EXPECTED_OUTPUT_STR)