mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
GPT Neo (#10848)
* lets begin * boom boom * fix out proj in attn * fix attention * fix local attention * add tokenizer * fix imports * autotokenizer * fix checkpoint name * cleanup * more clean-up * more cleanup * output attentions * fix attn mask creation * fix imports * config doc * add tests * add slow tests * quality * add conversion script * copyright * typo * another bites the dust * fix attention tests * doc * add embed init in convert function * fix copies * remove tokenizer * enable caching * address review comments * improve config and create attn layer list internally * more consistent naming * init hf config from mesh-tf config json file * remove neo tokenizer from doc * handle attention_mask in local attn layer * attn_layers => attention_layers * add tokenizer_class in config * fix docstring * raise if len of attention_layers is not same as num_layers * remove tokenizer_class from config * more consistent naming * fix doc * fix checkpoint names * fp16 compat * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
a04eb8d369
commit
860264379f
@ -213,6 +213,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[Funnel Transformer](https://huggingface.co/transformers/model_doc/funnel.html)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
|
||||
1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
|
||||
1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
|
||||
1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
|
||||
1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
|
||||
1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
|
@ -151,79 +151,81 @@ and conversion utilities for the following models:
|
||||
22. :doc:`GPT-2 <model_doc/gpt2>` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask
|
||||
Learners <https://blog.openai.com/better-language-models/>`__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David
|
||||
Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
23. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
|
||||
23. :doc:`GPT Neo <model_doc/gpt_neo>` (from EleutherAI) released in the repository `EleutherAI/gpt-neo
|
||||
<https://github.com/EleutherAI/gpt-neo>`__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
|
||||
24. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
|
||||
<https://arxiv.org/abs/2101.01321>`__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
|
||||
24. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
|
||||
25. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
|
||||
of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li,
|
||||
Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
|
||||
25. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
|
||||
26. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
|
||||
<https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
26. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
|
||||
27. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
|
||||
Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
27. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
|
||||
28. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
|
||||
Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
|
||||
by Hao Tan and Mohit Bansal.
|
||||
28. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
|
||||
29. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
|
||||
Machine Translation <https://arxiv.org/abs/2010.11125>`__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi
|
||||
Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman
|
||||
Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
29. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
30. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
|
||||
Translator Team.
|
||||
30. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
31. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
|
||||
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||
31. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
|
||||
32. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
|
||||
Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
|
||||
Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||
32. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
33. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
|
||||
Jianfeng Lu, Tie-Yan Liu.
|
||||
33. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
34. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
|
||||
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
34. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
35. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__> by Jingqing Zhang, Yao Zhao,
|
||||
Mohammad Saleh and Peter J. Liu.
|
||||
35. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
36. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
|
||||
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
36. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
37. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||
37. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
38. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
||||
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
38. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
39. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
|
||||
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
|
||||
39. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
40. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
||||
Krishna, and Kurt W. Keutzer.
|
||||
40. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
41. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
41. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
42. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
42. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
43. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
43. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
44. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
44. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
45. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
45. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
46. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
46. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
47. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
47. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
48. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
48. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
49. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||
|
||||
@ -280,6 +282,8 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
@ -443,6 +447,7 @@ TensorFlow and/or Flax.
|
||||
model_doc/mt5
|
||||
model_doc/gpt
|
||||
model_doc/gpt2
|
||||
model_doc/gpt_neo
|
||||
model_doc/pegasus
|
||||
model_doc/phobert
|
||||
model_doc/prophetnet
|
||||
|
65
docs/source/model_doc/gpt_neo.rst
Normal file
65
docs/source/model_doc/gpt_neo.rst
Normal file
@ -0,0 +1,65 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
GPT Neo
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The GPTNeo model was released in the `EleutherAI/gpt-neo <https://github.com/EleutherAI/gpt-neo>`__ repository by Sid
|
||||
Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. It is a GPT2 like causal language model trained on the
|
||||
`Pile <https://pile.eleuther.ai/>`__ dataset.
|
||||
|
||||
The architecture is similar to GPT2 except that GPT Neo uses local attention in every other layer with a window size of
|
||||
256 tokens.
|
||||
|
||||
Generation
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
The :obj:`generate()` method can be used to generate text using GPT Neo model.
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import GPTNeoForCausalLM, GPT2Tokenizer
|
||||
>>> model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl")
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt_neo_xl")
|
||||
|
||||
>>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
|
||||
... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
|
||||
... "researchers was the fact that the unicorns spoke perfect English."
|
||||
|
||||
>>> input_ids = tokenizer(unicorns, return_tensors="pt").input_ids
|
||||
|
||||
>>> gen_tokens = model.generate(ids, do_sample=True, temperature=0.9, max_length=100,)
|
||||
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
|
||||
|
||||
|
||||
GPTNeoConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.GPTNeoConfig
|
||||
:members:
|
||||
|
||||
|
||||
GPTNeoModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.GPTNeoModel
|
||||
:members: forward
|
||||
|
||||
|
||||
GPTNeoForCausalLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.GPTNeoForCausalLM
|
||||
:members: forward
|
@ -139,6 +139,12 @@ For the full list, refer to `https://huggingface.co/models <https://huggingface.
|
||||
| | ``gpt2-xl`` | | 48-layer, 1600-hidden, 25-heads, 1558M parameters. |
|
||||
| | | | OpenAI's XL-sized GPT-2 English model |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| GPTNeo | ``EleutherAI/gpt_neo_xl`` | | 24-layer, 2048-hidden, 16-heads, 1.3B parameters. |
|
||||
| | | | EleutherAI's GPT-3 like language model. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``EleutherAI/gpt_neo_2.7B`` | | 32-layer, 2560-hidden, 20-heads, 2.7B parameters. |
|
||||
| | | | EleutherAI's GPT-3 like language model. |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
||||
| | | | English model trained on wikitext-103 |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
@ -177,6 +177,7 @@ _import_structure = {
|
||||
"models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"],
|
||||
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
|
||||
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
|
||||
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
|
||||
"models.herbert": ["HerbertTokenizer"],
|
||||
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
|
||||
"models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
|
||||
@ -407,6 +408,7 @@ if is_torch_available():
|
||||
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
|
||||
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
|
||||
# PyTorch models structure
|
||||
|
||||
_import_structure["models.albert"].extend(
|
||||
[
|
||||
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -652,6 +654,15 @@ if is_torch_available():
|
||||
"load_tf_weights_in_gpt2",
|
||||
]
|
||||
)
|
||||
_import_structure["models.gpt_neo"].extend(
|
||||
[
|
||||
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoForCausalLM",
|
||||
"GPTNeoModel",
|
||||
"GPTNeoPreTrainedModel",
|
||||
"load_tf_weights_in_gpt_neo",
|
||||
]
|
||||
)
|
||||
_import_structure["models.ibert"].extend(
|
||||
[
|
||||
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -1420,6 +1431,7 @@ if TYPE_CHECKING:
|
||||
from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer
|
||||
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
|
||||
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
|
||||
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
||||
from .models.herbert import HerbertTokenizer
|
||||
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
|
||||
from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
|
||||
@ -1835,6 +1847,13 @@ if TYPE_CHECKING:
|
||||
GPT2PreTrainedModel,
|
||||
load_tf_weights_in_gpt2,
|
||||
)
|
||||
from .models.gpt_neo import (
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoModel,
|
||||
GPTNeoPreTrainedModel,
|
||||
load_tf_weights_in_gpt_neo,
|
||||
)
|
||||
from .models.ibert import (
|
||||
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
IBertForMaskedLM,
|
||||
|
@ -41,6 +41,7 @@ from . import (
|
||||
fsmt,
|
||||
funnel,
|
||||
gpt2,
|
||||
gpt_neo,
|
||||
herbert,
|
||||
layoutlm,
|
||||
led,
|
||||
|
@ -41,6 +41,7 @@ from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE
|
||||
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
|
||||
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
|
||||
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
||||
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
|
||||
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
|
||||
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
|
||||
@ -81,6 +82,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
# Add archive maps here
|
||||
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -129,6 +131,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
CONFIG_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("gpt_neo", GPTNeoConfig),
|
||||
("big_bird", BigBirdConfig),
|
||||
("speech_to_text", Speech2TextConfig),
|
||||
("wav2vec2", Wav2Vec2Config),
|
||||
@ -183,6 +186,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("gpt_neo", "GPT Neo"),
|
||||
("big_bird", "BigBird"),
|
||||
("speech_to_text", "Speech2Text"),
|
||||
("wav2vec2", "Wav2Vec2"),
|
||||
|
@ -21,8 +21,6 @@ from collections import OrderedDict
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...utils import logging
|
||||
|
||||
# Add modeling imports here
|
||||
from ..albert.modeling_albert import (
|
||||
AlbertForMaskedLM,
|
||||
AlbertForMultipleChoice,
|
||||
@ -137,6 +135,9 @@ from ..funnel.modeling_funnel import (
|
||||
FunnelModel,
|
||||
)
|
||||
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
|
||||
|
||||
# Add modeling imports here
|
||||
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel
|
||||
from ..ibert.modeling_ibert import (
|
||||
IBertForMaskedLM,
|
||||
IBertForMultipleChoice,
|
||||
@ -289,6 +290,7 @@ from .configuration_auto import (
|
||||
FSMTConfig,
|
||||
FunnelConfig,
|
||||
GPT2Config,
|
||||
GPTNeoConfig,
|
||||
IBertConfig,
|
||||
LayoutLMConfig,
|
||||
LEDConfig,
|
||||
@ -326,6 +328,7 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(GPTNeoConfig, GPTNeoModel),
|
||||
(BigBirdConfig, BigBirdModel),
|
||||
(Speech2TextConfig, Speech2TextModel),
|
||||
(Wav2Vec2Config, Wav2Vec2Model),
|
||||
@ -458,6 +461,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
(GPTNeoConfig, GPTNeoForCausalLM),
|
||||
(BigBirdConfig, BigBirdForCausalLM),
|
||||
(CamembertConfig, CamembertForCausalLM),
|
||||
(XLMRobertaConfig, XLMRobertaForCausalLM),
|
||||
|
70
src/transformers/models/gpt_neo/__init__.py
Normal file
70
src/transformers/models/gpt_neo/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
|
||||
"tokenization_gpt_neo": ["GPTNeoTokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
_import_structure["tokenization_gpt_neo_fast"] = ["GPTNeoTokenizerFast"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_gpt_neo"] = [
|
||||
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoForCausalLM",
|
||||
"GPTNeoModel",
|
||||
"GPTNeoPreTrainedModel",
|
||||
"load_tf_weights_in_gpt_neo",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_gpt_neo import (
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoModel,
|
||||
GPTNeoPreTrainedModel,
|
||||
load_tf_weights_in_gpt_neo,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
|
175
src/transformers/models/gpt_neo/configuration_gpt_neo.py
Normal file
175
src/transformers/models/gpt_neo/configuration_gpt_neo.py
Normal file
@ -0,0 +1,175 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" GPT Neo model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"EleutherAI/gpt_neo_xl": "https://huggingface.co/EleutherAI/gpt_neo_xl/resolve/main/config.json",
|
||||
# See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
|
||||
}
|
||||
|
||||
|
||||
class GPTNeoConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.GPTNeoModel`. It is used to
|
||||
instantiate a GPT Neo model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the GPTNeo `gpt_neo_xl
|
||||
<https://huggingface.co/EleutherAI/gpt_neo_xl>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50257):
|
||||
Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.GPTNeoModel`. Vocabulary size of the model.
|
||||
Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of
|
||||
:class:`~transformers.GPTNeoModel`.
|
||||
attention_types (:obj:`List`, `optional`, defaults to :obj:`[[["global", "local"], 12]]`):
|
||||
The type of attention for each layer in a :obj:`List` of the following format :obj:`[[["attention_type"],
|
||||
num_layerss]]` e.g. for a 24 layer model :obj:`[[["global"], 24]]` or :obj:`[[["global", "local"], 12]]`
|
||||
Choose the value of ``attention_type`` from :obj:`["global", "local"]`
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 2048):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_layers (:obj:`int`, `optional`, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 8192):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.GPTNeoModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``config.is_decoder=True``.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import GPTNeoModel, GPTNeoConfig
|
||||
|
||||
>>> # Initializing a GPTNeo EleutherAI/gpt_neo_xl style configuration
|
||||
>>> configuration = GPTNeoConfig()
|
||||
|
||||
>>> # Initializing a model from the EleutherAI/gpt_neo_xl style configuration
|
||||
>>> model = GPTNeoModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "gpt_neo"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50257,
|
||||
max_position_embeddings=2048,
|
||||
hidden_size=2048,
|
||||
num_layers=24,
|
||||
attention_types=[[["global", "local"], 12]],
|
||||
num_heads=16,
|
||||
intermediate_size=None,
|
||||
window_size=256,
|
||||
activation_function="gelu_new",
|
||||
resid_dropout=0.0,
|
||||
embed_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.window_size = window_size
|
||||
self.activation_function = activation_function
|
||||
self.resid_dropout = resid_dropout
|
||||
self.embed_dropout = embed_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.summary_type = summary_type
|
||||
self.summary_use_proj = summary_use_proj
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.use_cache = use_cache
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
self.attention_types = attention_types
|
||||
self.attention_layers = self.expand_attention_types_params(attention_types)
|
||||
|
||||
if len(self.attention_layers) != self.num_layers:
|
||||
raise ValueError(
|
||||
"Configuration for convolutional module is incorrect."
|
||||
"It is required that `len(config.attention_layers)` == `config.num_layers`"
|
||||
f"but is `len(config.attention_layers) = {len(self.attention_layers)}`,"
|
||||
f"`config.num_layers = {self.num_layers}`."
|
||||
"`config.attention_layers` is prepared using `config.attention_types`."
|
||||
"Please verify the value of `config.attention_types` argument."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def expand_attention_types_params(attention_types):
|
||||
attentions = []
|
||||
for item in attention_types:
|
||||
for _ in range(item[1]):
|
||||
attentions.extend(item[0])
|
||||
return attentions
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.num_heads
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.num_layers
|
@ -0,0 +1,70 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Eleuther AI and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert GPT Neo checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config_json = json.load(open(config_file, "r"))
|
||||
config = GPTNeoConfig(
|
||||
hidden_size=config_json["n_embd"],
|
||||
num_layers=config_json["n_layer"],
|
||||
num_heads=config_json["n_head"],
|
||||
attention_types=config_json["attention_types"],
|
||||
max_position_embeddings=config_json["n_ctx"],
|
||||
resid_dropout=config_json["res_dropout"],
|
||||
embed_dropout=config_json["embed_dropout"],
|
||||
attention_dropout=config_json["attn_dropout"],
|
||||
)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = GPTNeoForCausalLM(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained mesh-tf model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
964
src/transformers/models/gpt_neo/modeling_gpt_neo.py
Executable file
964
src/transformers/models/gpt_neo/modeling_gpt_neo.py
Executable file
@ -0,0 +1,964 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch GPT Neo model. """
|
||||
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_gpt_neo import GPTNeoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "GPTNeoConfig"
|
||||
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
||||
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"EleutherAI/gpt_neo_xl",
|
||||
# See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
|
||||
]
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt_neo_xl"
|
||||
|
||||
|
||||
def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
|
||||
"""Load tf checkpoints in a pytorch model"""
|
||||
try:
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(gpt_neo_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
if "global_step" not in name and "adam" not in name:
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
|
||||
name = name.replace("attn/q", "attn/attention/q_proj/w")
|
||||
name = name.replace("attn/k", "attn/attention/k_proj/w")
|
||||
name = name.replace("attn/v", "attn/attention/v_proj/w")
|
||||
name = name.replace("attn/o", "attn/attention/out_proj/w")
|
||||
name = name.replace("norm_1", "ln_1")
|
||||
name = name.replace("norm_2", "ln_2")
|
||||
name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
|
||||
name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
|
||||
name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
|
||||
name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
|
||||
name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")
|
||||
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[5:] # skip "gpt2/"
|
||||
name = name.split("/")
|
||||
pointer = model.transformer
|
||||
for m_name in name:
|
||||
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
||||
scope_names = re.split(r"(\d+)", m_name)
|
||||
else:
|
||||
scope_names = [m_name]
|
||||
if scope_names[0] == "w" or scope_names[0] == "g":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif scope_names[0] == "b":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
if len(scope_names) >= 2:
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
|
||||
if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
|
||||
array = array.transpose()
|
||||
|
||||
try:
|
||||
assert (
|
||||
pointer.shape == array.shape
|
||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}"
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# init the final linear layer using word embeddings
|
||||
embs = model.transformer.wte.weight
|
||||
lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
|
||||
lin.weight = embs
|
||||
model.set_output_embeddings(lin)
|
||||
return model
|
||||
|
||||
|
||||
class GPTNeoSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.resid_dropout = nn.Dropout(config.resid_dropout)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
|
||||
|
||||
def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
|
||||
attn_weights = torch.matmul(q, k)
|
||||
nd, ns = attn_weights.size(-2), attn_weights.size(-1)
|
||||
|
||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||
attn_weights = torch.where(mask.bool(), attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
||||
attn_weights = attn_weights.to(v.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
outputs = (torch.matmul(attn_weights, v),)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
return outputs
|
||||
|
||||
def merge_heads(self, x):
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
||||
|
||||
def split_heads(self, x, k=False):
|
||||
new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
|
||||
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||
if k:
|
||||
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
|
||||
query = self.q_proj(hidden_states)
|
||||
key = self.k_proj(hidden_states)
|
||||
value = self.v_proj(hidden_states)
|
||||
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
||||
key = torch.cat((past_key, key), dim=-1)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key.transpose(-2, -1), value) # transpose to have same shapes
|
||||
else:
|
||||
present = None
|
||||
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
|
||||
a = attn_outputs[0]
|
||||
|
||||
a = self.merge_heads(a)
|
||||
a = self.out_proj(a)
|
||||
a = self.resid_dropout(a)
|
||||
|
||||
return (a, present) + attn_outputs[1:] # a, present, (attentions)
|
||||
|
||||
|
||||
class GPTNeoLocalSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.resid_dropout = nn.Dropout(config.resid_dropout)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
|
||||
|
||||
self.window_size = config.window_size
|
||||
|
||||
def shift(self, x, offset, pad_value=0, dim=2):
|
||||
t = x.shape[1]
|
||||
dims = (len(x.shape) - dim) * (0, 0)
|
||||
padded_x = F.pad(x, (*dims, offset, 0), value=pad_value)
|
||||
return padded_x[:, :t, ...]
|
||||
|
||||
def look_around(self, x, block_length, window_size):
|
||||
num_complete_blocks = window_size // block_length
|
||||
|
||||
parts = [x]
|
||||
for i in range(1, num_complete_blocks + 1):
|
||||
parts = [self.shift(x, i)] + parts
|
||||
|
||||
partial_size = window_size % block_length
|
||||
if partial_size > 0:
|
||||
margin = x[:, :, block_length - partial_size : block_length, ...]
|
||||
parts = [self.shift(margin, num_complete_blocks + 1)] + parts
|
||||
return torch.cat(parts, dim=2)
|
||||
|
||||
def split_heads(self, x, k=False):
|
||||
new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
|
||||
x = x.view(*new_x_shape)
|
||||
if k:
|
||||
return x.permute(0, 1, 3, 4, 2) # (batch, chunks, head, head_features, seq_length)
|
||||
else:
|
||||
return x.permute(0, 1, 3, 2, 4) # (batch, chunks, head, seq_length, head_features)
|
||||
|
||||
def merge_heads(self, x):
|
||||
x = x.permute(0, 1, 3, 2, 4).contiguous()
|
||||
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||
return x.view(*new_x_shape)
|
||||
|
||||
def _split_seq_length_dim_to(self, tensors, num_blocks, block_length):
|
||||
return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)
|
||||
|
||||
def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask):
|
||||
ticker = torch.arange(seq_len)[None, :]
|
||||
b_t = ticker.reshape(1, windows, block_length)
|
||||
|
||||
bq_t = b_t
|
||||
bq_k = self.look_around(b_t, block_length, self.window_size)
|
||||
|
||||
# compute attn mask
|
||||
# this matches the original implem in mess-tensorflow
|
||||
# https://github.com/tensorflow/mesh/blob/8bd599a21bad01cef1300a8735c17306ce35db6e/mesh_tensorflow/transformer/attention.py#L805
|
||||
relative_position = bq_k.unsqueeze(-2) - bq_t.unsqueeze(-1)
|
||||
relative_position = relative_position.transpose(-1, -2)
|
||||
|
||||
sequence_id = torch.ones(bs, seq_len)
|
||||
q_seq = sequence_id.reshape(-1, windows, block_length)
|
||||
m_seq = sequence_id.reshape(-1, windows, block_length)
|
||||
m_seq = self.look_around(m_seq, block_length, self.window_size)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(m_seq.device)
|
||||
attention_mask = attention_mask.reshape(-1, windows, block_length)
|
||||
attention_mask = self.look_around(attention_mask, block_length, self.window_size)
|
||||
m_seq *= attention_mask
|
||||
|
||||
visible = torch.eq(q_seq.unsqueeze(-1), m_seq.unsqueeze(-2)).transpose(-1, -2)
|
||||
visible = torch.logical_and(visible, torch.gt(relative_position, -self.window_size))
|
||||
mask = torch.logical_and(visible, torch.less_equal(relative_position, 0)).transpose(-1, -2).unsqueeze(2)
|
||||
return mask
|
||||
|
||||
def _attn(self, q, k, v, causal_mask, head_mask=None, output_attentions=False):
|
||||
# attn
|
||||
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
|
||||
attn_weights = torch.matmul(q, k)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
||||
attn_weights = attn_weights.to(v.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, v)
|
||||
|
||||
outputs = (attn_output,)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
query = self.q_proj(hidden_states)
|
||||
|
||||
if layer_past is not None:
|
||||
past = layer_past[0]
|
||||
key_value_hidden_states = torch.cat([past, hidden_states], dim=1)
|
||||
past_length = past.size()[1]
|
||||
else:
|
||||
key_value_hidden_states = hidden_states
|
||||
past_length = 0
|
||||
|
||||
key = self.k_proj(key_value_hidden_states)
|
||||
value = self.v_proj(key_value_hidden_states)
|
||||
|
||||
# compute block length and windows
|
||||
bs, seq_len = hidden_states.shape[:2]
|
||||
full_seq_length = seq_len + past_length
|
||||
block_length = self.window_size
|
||||
while full_seq_length % block_length != 0:
|
||||
block_length -= 1
|
||||
num_blocks = full_seq_length // block_length
|
||||
|
||||
# create buckets
|
||||
if layer_past is not None:
|
||||
# we just need 1 window with block_length 1 when caching is enabled
|
||||
query = self._split_seq_length_dim_to(query, 1, 1)
|
||||
else:
|
||||
query = self._split_seq_length_dim_to(query, num_blocks, block_length)
|
||||
|
||||
key = self._split_seq_length_dim_to(key, num_blocks, block_length)
|
||||
value = self._split_seq_length_dim_to(value, num_blocks, block_length)
|
||||
|
||||
key = self.look_around(key, block_length, self.window_size)
|
||||
value = self.look_around(value, block_length, self.window_size)
|
||||
|
||||
# select key/value vectors only for the last window
|
||||
if layer_past is not None:
|
||||
key = key[:, -1:, ...]
|
||||
value = value[:, -1:, ...]
|
||||
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
|
||||
mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask)
|
||||
if layer_past is not None:
|
||||
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last window
|
||||
mask = mask.to(hidden_states.device)
|
||||
|
||||
# attn
|
||||
attn_outputs = self._attn(query, key, value, mask, head_mask, output_attentions)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
attn = self.merge_heads(attn)
|
||||
attn = attn.reshape(bs, seq_len, self.embed_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
attn = self.resid_dropout(attn)
|
||||
return (attn,) + attn_outputs[1:]
|
||||
|
||||
|
||||
class GPTNeoAttention(nn.Module):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.attention_layers = config.attention_layers
|
||||
self.attention_type = self.attention_layers[layer_id]
|
||||
|
||||
if self.attention_type == "global":
|
||||
self.attention = GPTNeoSelfAttention(config)
|
||||
elif self.attention_type == "local":
|
||||
self.attention = GPTNeoLocalSelfAttention(config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {}. Select attn layer types from ['global', 'local'] only.".format(
|
||||
self.attention_layers
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
outputs = self.attention(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# cache the hidden_states instead of key_value_states
|
||||
# for local attention layer
|
||||
if self.attention_type == "local":
|
||||
if layer_past is None:
|
||||
past = hidden_states
|
||||
else:
|
||||
past = torch.cat([layer_past[0], hidden_states], dim=1)
|
||||
outputs = (outputs[0], (past,)) + outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
||||
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.dropout = nn.Dropout(config.resid_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.act(self.c_fc(x))
|
||||
h2 = self.c_proj(h)
|
||||
return self.dropout(h2)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_id):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTNeoAttention(config, layer_id)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = MLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
attn_outputs = self.attn(
|
||||
self.ln_1(hidden_states),
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
||||
# residual connection
|
||||
hidden_states = hidden_states + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class GPTNeoPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = GPTNeoConfig
|
||||
load_tf_weights = load_tf_weights_in_gpt_neo
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear,)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GPT_NEO_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.GPTNeoConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
GPT_NEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
||||
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
||||
sequence tokens in the vocabulary.
|
||||
|
||||
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
||||
passed as ``input_ids``.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.GPTNeoTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.num_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
||||
computed.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 0 corresponds to a `sentence A` token,
|
||||
- 1 corresponds to a `sentence B` token.
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
|
||||
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
|
||||
:obj:`past_key_values`).
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class GPTNeoModel(GPTNeoPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.drop = nn.Dropout(config.embed_dropout)
|
||||
self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.num_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.wte = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
assert batch_size > 0, "batch_size has to be defined and > 0"
|
||||
global_attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
global_attention_mask = global_attention_mask[:, None, None, :]
|
||||
|
||||
# Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
global_attention_mask = (1.0 - global_attention_mask) * -10000.0
|
||||
else:
|
||||
global_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x num_headss x N x N
|
||||
# head_mask has shape n_layer x batch x num_headss x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
attn_type = self.config.attention_layers[i]
|
||||
attn_mask = global_attention_mask if attn_type == "global" else attention_mask
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attn_mask,
|
||||
head_mask[i],
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attn_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||
_keys_to_ignore_on_save = [r"lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = GPTNeoModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=CausalLMOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
||||
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Compute loss in fp32 to match with mesh-tf version
|
||||
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
||||
lm_logits = lm_logits.to(torch.float32)
|
||||
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
lm_logits = lm_logits.to(hidden_states.dtype)
|
||||
loss = loss.to(hidden_states.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
@ -1449,6 +1449,36 @@ def load_tf_weights_in_gpt2(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_gpt2)
|
||||
|
||||
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class GPTNeoForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class GPTNeoModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class GPTNeoPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
def load_tf_weights_in_gpt_neo(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_gpt_neo)
|
||||
|
||||
|
||||
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
511
tests/test_modeling_gpt_neo.py
Normal file
511
tests/test_modeling_gpt_neo.py
Normal file
@ -0,0 +1,511 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch GPT Neo model. """
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPT2Tokenizer,
|
||||
GPTNeoConfig,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoModel,
|
||||
)
|
||||
|
||||
|
||||
class GPTNeoModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_token_type_ids=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
use_mc_token_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=4,
|
||||
attention_types=[[["global", "local"], 2]],
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
window_size=7,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.use_mc_token_ids = use_mc_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.window_size = window_size
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
self.pad_token_id = vocab_size - 1
|
||||
self.chunk_length = window_size
|
||||
self.attention_types = attention_types
|
||||
|
||||
def get_large_model_config(self):
|
||||
return GPTNeoConfig.from_pretrained("gpt_neo")
|
||||
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
mc_token_ids = None
|
||||
if self.use_mc_token_ids:
|
||||
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = GPTNeoConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_hidden_layers,
|
||||
num_heads=self.num_attention_heads,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
use_cache=not gradient_checkpointing,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
window_size=self.window_size,
|
||||
attention_types=self.attention_types,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_gpt_neo_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPTNeoModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
# past_key_values is not implemented
|
||||
# self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
def create_and_check_gpt_neo_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPTNeoModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
|
||||
outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)
|
||||
|
||||
# append to next input_ids and token_type_ids
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPTNeoForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPTNeoForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
result.loss.backward()
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_model_parallel = False
|
||||
|
||||
# special case for DoubleHeads model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = GPTNeoModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=GPTNeoConfig, n_embd=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_gpt_neo_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_model(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_model_past(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_lm_head_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
|
||||
|
||||
def _get_local_attn_seq_len_block_len_windows(self, seq_len, window_size):
|
||||
block_length = window_size
|
||||
while seq_len % block_length != 0:
|
||||
block_length -= 1
|
||||
windows = seq_len // block_length
|
||||
local_seq_len = window_size + block_length
|
||||
return local_seq_len, block_length, windows
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# test global attention shape
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, seq_len],
|
||||
)
|
||||
# test local attention shape
|
||||
encoder_key_length = self._get_local_attn_seq_len_block_len_windows(seq_len, chunk_length)[0]
|
||||
self.assertListEqual(
|
||||
list(attentions[-1].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, seq_len, encoder_key_length],
|
||||
)
|
||||
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# test global attention shape
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, seq_len],
|
||||
)
|
||||
|
||||
# test local attention shape
|
||||
self.assertListEqual(
|
||||
list(self_attentions[-1].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, seq_len, encoder_key_length],
|
||||
)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx
|
||||
global_expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
|
||||
local_seq_len, block_len, windows = self._get_local_attn_seq_len_block_len_windows(
|
||||
src_len, config.window_size
|
||||
)
|
||||
block_len = 1 if use_cache else block_len
|
||||
local_expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
windows,
|
||||
config.num_attention_heads,
|
||||
block_len,
|
||||
local_seq_len,
|
||||
)
|
||||
|
||||
shapes = [layer_attention.shape for layer_attention in iter_attentions]
|
||||
# every other layer is local attention layers
|
||||
# so alternate between expected shapes
|
||||
expected_shape = [
|
||||
global_expected_shape if i % 2 == 0 else local_expected_shape for i, _ in enumerate(iter_attentions)
|
||||
]
|
||||
# check attn size
|
||||
self.assertListEqual(shapes, expected_shape)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl")
|
||||
model.to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Define PAD Token = EOS Token = 50256
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
# use different length sentences to test batching
|
||||
sentences = [
|
||||
"Hello, my dog is a little",
|
||||
"Today, I am",
|
||||
]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
)
|
||||
|
||||
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||
|
||||
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
|
||||
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||
|
||||
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||
|
||||
expected_output_sentence = [
|
||||
"Hello, my dog is a little bit of a kitty. She is a very sweet and loving",
|
||||
"Today, I am going to talk about the best way to get a job in the",
|
||||
]
|
||||
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = GPTNeoModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt_neo(self):
|
||||
for checkpointing in [True, False]:
|
||||
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl", gradient_checkpointing=checkpointing)
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
# fmt: off
|
||||
expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11] # The dog-eared copy of the book, which is a collection of essays by the late author,
|
||||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_gpt_neo_sample(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt_neo_xl")
|
||||
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
|
||||
input_ids = tokenized.input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids, do_sample=True)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can"
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
Loading…
Reference in New Issue
Block a user