Mirror of https://github.com/huggingface/transformers.git
Add XGLM models (#14876)
* add xglm
* update vocab size
* fix model name
* style and tokenizer
* typo
* no mask token
* fix pos embed compute
* fix args
* fix tokenizer
* fix positions
* fix tokenization
* style and dic fixes
* fix imports
* add fast tokenizer
* update names
* add pt tests
* fix tokenizer
* fix typo
* fix tokenizer import
* fix fast tokenizer
* fix tokenizer
* fix converter
* add tokenizer test
* update checkpoint names
* fix tokenizer tests
* fix slow tests
* add copied from comments
* rst -> mdx
* flax model
* update flax tests
* quality
* style
* doc
* update index and readme
* fix copies
* fix doc
* update toctrr
* fix indent
* minor fixes
* fix config doc
* don't save embed_pos weights
* Apply suggestions from code review
* address Sylvains commnets, few doc fixes
* fix check_repo
* align order of arguments
* fix copies
* fix labels
* remove unnecessary mapping
* fix saving tokenizer

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent b6b79faa7e
commit d25e25ee2b
@ -319,6 +319,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch
|
||||
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
|
||||
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/master/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
|
||||
1. **[XGLM](https://huggingface.co/docs/transformers/master/model_doc/xglm)** (from Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
|
@ -297,6 +297,7 @@ the Flax, PyTorch, and TensorFlow installation pages show how to install them with conda
|
||||
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/master/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
|
||||
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
|
||||
1. **[XGLM](https://huggingface.co/docs/transformers/master/model_doc/xglm)** (from Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
|
@ -321,6 +321,7 @@ conda install -c huggingface transformers
|
||||
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/master/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
|
||||
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
|
||||
1. **[XGLM](https://huggingface.co/docs/transformers/master/model_doc/xglm)** (from Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
|
@ -333,6 +333,7 @@ conda install -c huggingface transformers
|
||||
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/master/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
|
||||
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
|
||||
1. **[XGLM](https://huggingface.co/docs/transformers/master/model_doc/xglm)** (from Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
|
@ -304,6 +304,8 @@
|
||||
title: Wav2Vec2Phoneme
|
||||
- local: model_doc/wavlm
|
||||
title: WavLM
|
||||
- local: model_doc/xglm
|
||||
title: XGLM
|
||||
- local: model_doc/xlm
|
||||
title: XLM
|
||||
- local: model_doc/xlm-prophetnet
|
||||
|
@ -178,6 +178,7 @@ conversion utilities for the following models.
|
||||
1. **[WavLM](model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
|
||||
1. **[Wav2Vec2](model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[Wav2Vec2Phoneme](model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
|
||||
1. **[XGLM](model_doc/xglm)** (from Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
1. **[XLM](model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[XLM-RoBERTa](model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
@ -278,6 +279,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| ViTMAE | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
|
||||
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
75
docs/source/model_doc/xglm.mdx
Normal file
@ -0,0 +1,75 @@
|
||||
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# XGLM
|
||||
|
||||
## Overview
|
||||
|
||||
The XGLM model was proposed in [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668)
|
||||
by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal,
|
||||
Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo,
|
||||
Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Large-scale autoregressive language models such as GPT-3 are few-shot learners that can perform a wide range of language
|
||||
tasks without fine-tuning. While these models are known to be able to jointly represent many different languages,
|
||||
their training data is dominated by English, potentially limiting their cross-lingual generalization.
|
||||
In this work, we train multilingual autoregressive language models on a balanced corpus covering a diverse set of languages,
|
||||
and study their few- and zero-shot learning capabilities in a wide range of tasks. Our largest model with 7.5 billion parameters
|
||||
sets new state of the art in few-shot learning in more than 20 representative languages, outperforming GPT-3 of comparable size
|
||||
in multilingual commonsense reasoning (with +7.4% absolute accuracy improvement in 0-shot settings and +9.4% in 4-shot settings)
|
||||
and natural language inference (+5.4% in each of 0-shot and 4-shot settings). On the FLORES-101 machine translation benchmark,
|
||||
our model outperforms GPT-3 on 171 out of 182 translation directions with 32 training examples, while surpassing the
|
||||
official supervised baseline in 45 directions. We present a detailed analysis of where the model succeeds and fails,
|
||||
showing in particular that it enables cross-lingual in-context learning on some tasks, while there is still room for improvement
|
||||
on surface form robustness and adaptation to tasks that do not have a natural cloze form. Finally, we evaluate our models
|
||||
in social value tasks such as hate speech detection in five languages and find it has limitations similar to comparable sized GPT-3 models.*
|
||||
|
||||
|
||||
This model was contributed by [Suraj](https://huggingface.co/valhalla). The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/xglm).
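Since the classes documented below are introduced in this PR, a minimal usage sketch may help; it is hedged on the `facebook/xglm-564M` checkpoint referenced by this PR being available on the Hub:

```python
from transformers import XGLMTokenizer, XGLMForCausalLM

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

# Greedy generation from a short English prompt
inputs = tokenizer("Today is a beautiful day and", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```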
|
||||
|
||||
## XGLMConfig
|
||||
|
||||
[[autodoc]] XGLMConfig
|
||||
|
||||
## XGLMTokenizer
|
||||
|
||||
[[autodoc]] XGLMTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
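A short illustrative snippet (not part of this PR's files; it assumes the `facebook/xglm-564M` checkpoint) showing how the tokenizer applies its special tokens:

```python
from transformers import XGLMTokenizer

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")

encoded = tokenizer("Hello world")
# XGLM prepends "</s>" to a single sequence (see build_inputs_with_special_tokens above),
# so the first id is expected to correspond to "</s>" rather than a "<s>" BOS token.
print(encoded["input_ids"][0] == tokenizer.convert_tokens_to_ids("</s>"))  # expected: True
print(tokenizer.decode(encoded["input_ids"]))
```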
|
||||
|
||||
## XGLMTokenizerFast
|
||||
|
||||
[[autodoc]] XGLMTokenizerFast
|
||||
|
||||
## XGLMModel
|
||||
|
||||
[[autodoc]] XGLMModel
|
||||
- forward
|
||||
|
||||
## XGLMForCausalLM
|
||||
|
||||
[[autodoc]] XGLMForCausalLM
|
||||
- forward
|
||||
|
||||
## FlaxXGLMModel
|
||||
|
||||
[[autodoc]] FlaxXGLMModel
|
||||
- __call__
|
||||
|
||||
## FlaxXGLMForCausalLM
|
||||
|
||||
[[autodoc]] FlaxXGLMForCausalLM
|
||||
- __call__
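For the Flax classes above, a comparable sketch; `from_pt=True` is an assumption that is only needed if the Hub checkpoint does not ship Flax weights:

```python
from transformers import XGLMTokenizerFast, FlaxXGLMForCausalLM

tokenizer = XGLMTokenizerFast.from_pretrained("facebook/xglm-564M")
model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M", from_pt=True)

inputs = tokenizer("Today is a beautiful day and", return_tensors="np")
outputs = model(**inputs)
next_token_logits = outputs.logits[:, -1]  # scores for the token following the prompt
```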
|
@ -328,6 +328,7 @@ _import_structure = {
|
||||
"WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"WavLMConfig",
|
||||
],
|
||||
"models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
|
||||
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
||||
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||
@ -408,6 +409,7 @@ if is_sentencepiece_available():
|
||||
_import_structure["models.rembert"].append("RemBertTokenizer")
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
|
||||
_import_structure["models.t5"].append("T5Tokenizer")
|
||||
_import_structure["models.xglm"].append("XGLMTokenizer")
|
||||
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
|
||||
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
|
||||
_import_structure["models.xlnet"].append("XLNetTokenizer")
|
||||
@ -422,6 +424,7 @@ else:
|
||||
if is_tokenizers_available():
|
||||
# Fast tokenizers
|
||||
_import_structure["models.realm"].append("RealmTokenizerFast")
|
||||
_import_structure["models.xglm"].append("XGLMTokenizerFast")
|
||||
_import_structure["models.fnet"].append("FNetTokenizerFast")
|
||||
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
|
||||
_import_structure["models.clip"].append("CLIPTokenizerFast")
|
||||
@ -1461,6 +1464,14 @@ if is_torch_available():
|
||||
"WavLMPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.xglm"].extend(
|
||||
[
|
||||
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"XGLMForCausalLM",
|
||||
"XGLMModel",
|
||||
"XGLMPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.xlm"].extend(
|
||||
[
|
||||
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2203,6 +2214,13 @@ if is_flax_available():
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.xglm"].extend(
|
||||
[
|
||||
"FlaxXGLMForCausalLM",
|
||||
"FlaxXGLMModel",
|
||||
"FlaxXGLMPreTrainedModel",
|
||||
]
|
||||
)
|
||||
else:
|
||||
from .utils import dummy_flax_objects
|
||||
|
||||
@ -2463,6 +2481,7 @@ if TYPE_CHECKING:
|
||||
from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
|
||||
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
|
||||
from .models.xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
|
||||
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
|
||||
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
||||
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||
@ -2544,6 +2563,7 @@ if TYPE_CHECKING:
|
||||
from .models.rembert import RemBertTokenizer
|
||||
from .models.speech_to_text import Speech2TextTokenizer
|
||||
from .models.t5 import T5Tokenizer
|
||||
from .models.xglm import XGLMTokenizer
|
||||
from .models.xlm_prophetnet import XLMProphetNetTokenizer
|
||||
from .models.xlm_roberta import XLMRobertaTokenizer
|
||||
from .models.xlnet import XLNetTokenizer
|
||||
@ -2591,6 +2611,7 @@ if TYPE_CHECKING:
|
||||
from .models.splinter import SplinterTokenizerFast
|
||||
from .models.squeezebert import SqueezeBertTokenizerFast
|
||||
from .models.t5 import T5TokenizerFast
|
||||
from .models.xglm import XGLMTokenizerFast
|
||||
from .models.xlm_roberta import XLMRobertaTokenizerFast
|
||||
from .models.xlnet import XLNetTokenizerFast
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
@ -3404,6 +3425,7 @@ if TYPE_CHECKING:
|
||||
WavLMModel,
|
||||
WavLMPreTrainedModel,
|
||||
)
|
||||
from .models.xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
|
||||
from .models.xlm import (
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
XLMForMultipleChoice,
|
||||
@ -4018,6 +4040,7 @@ if TYPE_CHECKING:
|
||||
FlaxWav2Vec2Model,
|
||||
FlaxWav2Vec2PreTrainedModel,
|
||||
)
|
||||
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
||||
else:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
|
@ -910,6 +910,35 @@ class BlenderbotConverter(Converter):
|
||||
return tokenizer
|
||||
|
||||
|
||||
class XGLMConverter(SpmConverter):
|
||||
def vocab(self, proto):
|
||||
vocab = [
|
||||
("<s>", 0.0),
|
||||
("<pad>", 0.0),
|
||||
("</s>", 0.0),
|
||||
("<unk>", 0.0),
|
||||
]
|
||||
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
||||
# fmt: off
|
||||
vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]
|
||||
# fmt: on
|
||||
return vocab
|
||||
|
||||
def unk_id(self, proto):
|
||||
unk_id = 3
|
||||
return unk_id
|
||||
|
||||
def post_processor(self):
|
||||
return processors.TemplateProcessing(
|
||||
single="</s> $A",
|
||||
pair="</s> $A </s> </s> $B",
|
||||
special_tokens=[
|
||||
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
||||
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
||||
],
|
||||
)
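As an illustrative sanity check for this converter (a sketch, not part of this file; it assumes the `facebook/xglm-564M` checkpoint added elsewhere in this PR), the slow and fast tokenizers should agree, with the template above prepending `</s>` to a single sequence:

```python
from transformers import XGLMTokenizer, XGLMTokenizerFast

slow = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
fast = XGLMTokenizerFast.from_pretrained("facebook/xglm-564M")

text = "Few-shot learning with multilingual language models"
assert slow(text)["input_ids"] == fast(text)["input_ids"]
# TemplateProcessing(single="</s> $A") means every single sequence starts with "</s>"
assert fast(text)["input_ids"][0] == fast.convert_tokens_to_ids("</s>")
```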
|
||||
|
||||
|
||||
SLOW_TO_FAST_CONVERTERS = {
|
||||
"AlbertTokenizer": AlbertConverter,
|
||||
"BartTokenizer": RobertaConverter,
|
||||
@ -953,6 +982,7 @@ SLOW_TO_FAST_CONVERTERS = {
|
||||
"XLMRobertaTokenizer": XLMRobertaConverter,
|
||||
"XLNetTokenizer": XLNetConverter,
|
||||
"SplinterTokenizer": SplinterConverter,
|
||||
"XGLMTokenizer": XGLMConverter,
|
||||
}
|
||||
|
||||
|
||||
|
@ -116,6 +116,7 @@ from . import (
|
||||
wav2vec2_phoneme,
|
||||
wav2vec2_with_lm,
|
||||
wavlm,
|
||||
xglm,
|
||||
xlm,
|
||||
xlm_prophetnet,
|
||||
xlm_roberta,
|
||||
|
@ -36,6 +36,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("vit_mae", "ViTMAEConfig"),
|
||||
("realm", "RealmConfig"),
|
||||
("nystromformer", "NystromformerConfig"),
|
||||
("xglm", "XGLMConfig"),
|
||||
("imagegpt", "ImageGPTConfig"),
|
||||
("qdqbert", "QDQBertConfig"),
|
||||
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
|
||||
@ -128,6 +129,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@ -208,6 +210,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("vit_mae", "ViTMAE"),
|
||||
("realm", "Realm"),
|
||||
("nystromformer", "Nystromformer"),
|
||||
("xglm", "XGLM"),
|
||||
("imagegpt", "ImageGPT"),
|
||||
("qdqbert", "QDQBert"),
|
||||
("vision-encoder-decoder", "Vision Encoder decoder"),
|
||||
|
@ -33,6 +33,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("vilt", "ViltModel"),
|
||||
("vit_mae", "ViTMAEModel"),
|
||||
("nystromformer", "NystromformerModel"),
|
||||
("xglm", "XGLMModel"),
|
||||
("imagegpt", "ImageGPTModel"),
|
||||
("qdqbert", "QDQBertModel"),
|
||||
("fnet", "FNetModel"),
|
||||
@ -209,6 +210,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
("xglm", "XGLMForCausalLM"),
|
||||
("qdqbert", "QDQBertLMHeadModel"),
|
||||
("trocr", "TrOCRForCausalLM"),
|
||||
("gptj", "GPTJForCausalLM"),
|
||||
|
@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
||||
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("xglm", "FlaxXGLMModel"),
|
||||
("blenderbot-small", "FlaxBlenderbotSmallModel"),
|
||||
("pegasus", "FlaxPegasusModel"),
|
||||
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
|
||||
@ -121,6 +122,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("gpt2", "FlaxGPT2LMHeadModel"),
|
||||
("gpt_neo", "FlaxGPTNeoForCausalLM"),
|
||||
("gptj", "FlaxGPTJForCausalLM"),
|
||||
("xglm", "FlaxXGLMForCausalLM"),
|
||||
]
|
||||
)
|
||||
|
||||
|
66
src/transformers/models/xglm/__init__.py
Normal file
@ -0,0 +1,66 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# rely on isort to merge the imports
|
||||
from ...file_utils import _LazyModule, is_flax_available, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
|
||||
"tokenization_xglm": ["XGLMTokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
_import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_xglm"] = [
|
||||
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"XGLMForCausalLM",
|
||||
"XGLMModel",
|
||||
"XGLMPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_xglm"] = [
|
||||
"FlaxXGLMForCausalLM",
|
||||
"FlaxXGLMModel",
|
||||
"FlaxXGLMPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
|
||||
from .tokenization_xglm import XGLMTokenizer
|
||||
|
||||
if is_tokenizers_available():
|
||||
from .tokenization_xglm_fast import XGLMTokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
140
src/transformers/models/xglm/configuration_xglm.py
Normal file
@ -0,0 +1,140 @@
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" XGLM model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/config.json",
|
||||
# See all XGLM models at https://huggingface.co/models?filter=xglm
|
||||
}
|
||||
|
||||
|
||||
class XGLMConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`XGLMModel`]. It is used to instantiate an XGLM
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the XGLM
|
||||
[facebook/xglm-564M](https://huggingface.co/facebook/xglm-564M) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256008):
|
||||
Vocabulary size of the XGLM model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`XGLMModel`] or [`FlaxXGLMModel`].
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
d_model (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the layers and the pooler layer.
|
||||
ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
|
||||
num_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, decoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
scale_embedding (`bool`, *optional*, defaults to `True`):
|
||||
Scale embeddings by dividing by sqrt(d_model).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import XGLMModel, XGLMConfig
|
||||
|
||||
>>> # Initializing an XGLM facebook/xglm-564M style configuration
|
||||
>>> configuration = XGLMConfig()
|
||||
|
||||
>>> # Initializing a model from the facebook/xglm-564M style configuration
|
||||
>>> model = XGLMModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "xglm"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
attribute_map = {
|
||||
"num_attention_heads": "attention_heads",
|
||||
"hidden_size": "d_model",
|
||||
"num_hidden_layers": "num_layers",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256008,
|
||||
max_position_embeddings=2048,
|
||||
d_model=1024,
|
||||
ffn_dim=4096,
|
||||
num_layers=24,
|
||||
attention_heads=16,
|
||||
activation_function="gelu",
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
activation_dropout=0.0,
|
||||
layerdrop=0.0,
|
||||
init_std=0.02,
|
||||
scale_embedding=True,
|
||||
use_cache=True,
|
||||
decoder_start_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_layers = num_layers
|
||||
self.attention_heads = attention_heads
|
||||
self.activation_function = activation_function
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.layerdrop = layerdrop
|
||||
self.init_std = init_std
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.use_cache = use_cache
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
**kwargs,
|
||||
)
|
802
src/transformers/models/xglm/modeling_flax_xglm.py
Normal file
@ -0,0 +1,802 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Flax XGLM model."""
|
||||
|
||||
|
||||
import math
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
||||
from ...utils import logging
|
||||
from .configuration_xglm import XGLMConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
|
||||
_CONFIG_FOR_DOC = "XGLMConfig"
|
||||
_TOKENIZER_FOR_DOC = "XGLMTokenizer"
|
||||
|
||||
XGLM_START_DOCSTRING = r"""
|
||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a Flax Linen
|
||||
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
||||
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
config ([`XGLMConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
||||
`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given `dtype`.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
||||
[`~FlaxPreTrainedModel.to_bf16`].
|
||||
"""
|
||||
|
||||
XGLM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`~XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
def create_sinusoidal_positions(n_pos, dim, padding_idx=1):
|
||||
half_dim = dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = np.exp(np.arange(half_dim) * -emb)
|
||||
emb = np.expand_dims(np.arange(n_pos), 1) * np.expand_dims(emb, 0)
|
||||
emb = np.concatenate([np.sin(emb), np.cos(emb)], 1)
|
||||
emb = np.reshape(emb, (n_pos, dim))
|
||||
|
||||
if padding_idx is not None:
|
||||
emb[padding_idx, :] = 0
|
||||
|
||||
return jnp.array(emb)
|
||||
|
||||
|
||||
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
|
||||
"""
|
||||
Shift input ids one token to the right.
|
||||
"""
|
||||
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
||||
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
class FlaxXGLMAttention(nn.Module):
|
||||
config: XGLMConfig
|
||||
embed_dim: int
|
||||
num_heads: int
|
||||
dropout: float = 0.0
|
||||
causal: bool = False
|
||||
bias: bool = True
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self) -> None:
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
|
||||
"and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
dense = partial(
|
||||
nn.Dense,
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
self.out_proj = dense()
|
||||
|
||||
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
||||
|
||||
@nn.compact
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slightly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend
|
||||
# to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: jnp.ndarray,
|
||||
key_value_states: Optional[jnp.ndarray] = None,
|
||||
attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> Tuple[jnp.ndarray]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states)
|
||||
value_states = self.v_proj(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
||||
)
|
||||
else:
|
||||
attention_bias = None
|
||||
|
||||
dropout_rng = None
|
||||
if not deterministic and self.dropout > 0.0:
|
||||
dropout_rng = self.make_rng("dropout")
|
||||
|
||||
attn_weights = dot_product_attention_weights(
|
||||
query_states,
|
||||
key_states,
|
||||
bias=attention_bias,
|
||||
dropout_rng=dropout_rng,
|
||||
dropout_rate=self.dropout,
|
||||
broadcast_dropout=True,
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype,
|
||||
precision=None,
|
||||
)
|
||||
|
||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||
attn_output = self._merge_heads(attn_output)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class FlaxXGLMDecoderLayer(nn.Module):
|
||||
config: XGLMConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self) -> None:
|
||||
self.embed_dim = self.config.d_model
|
||||
self.self_attn = FlaxXGLMAttention(
|
||||
config=self.config,
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
self.encoder_attn = FlaxXGLMAttention(
|
||||
config=self.config,
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer.__call__
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = True,
|
||||
deterministic: bool = True,
|
||||
) -> Tuple[jnp.ndarray]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
|
||||
)
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
hidden_states, cross_attn_weights = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxXGLMDecoderLayerCollection(nn.Module):
|
||||
config: XGLMConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxXGLMDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_layers)
|
||||
]
|
||||
self.layerdrop = self.config.layerdrop
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
deterministic: bool = True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if not deterministic and (dropout_probability < self.layerdrop):
|
||||
layer_outputs = (None, None, None)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, all_hidden_states, all_self_attns, all_cross_attentions)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
class FlaxXGLMModule(nn.Module):
|
||||
config: XGLMConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
|
||||
embed_dim = self.config.d_model
|
||||
self.padding_idx = self.config.pad_token_id
|
||||
self.max_target_positions = self.config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models don't have this hack
|
||||
self.offset = 2
|
||||
self.embed_positions = create_sinusoidal_positions(
|
||||
self.config.max_position_embeddings + self.offset, embed_dim
|
||||
)
|
||||
self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
|
||||
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
input_shape = input_ids.shape
|
||||
input_ids = input_ids.reshape(-1, input_shape[-1])
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
# embed positions
|
||||
position_ids = position_ids + self.offset
|
||||
positions = jnp.take(self.embed_positions, position_ids, axis=0)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
outputs = self.layers(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
deterministic=deterministic,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
last_hidden_states = self.layer_norm(last_hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (last_hidden_states,) + outputs[1:]
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=last_hidden_states,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
|
||||
config_class = XGLMConfig
|
||||
base_model_prefix: str = "model"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: XGLMConfig,
|
||||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||
|
||||
return module_init_outputs["params"]
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
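# Illustrative sketch of how the cache above is meant to be used: `generate` (via
# `prepare_inputs_for_generation` further below) pre-allocates it once per generation call and then
# threads the updated cache back in through the `past_key_values` argument at every decoding step.
# The helper below is only an assumption-level example and is not exercised anywhere in this module.
def _sketch_init_cache_usage(model, batch_size=1, max_length=32):
    # nested dict with one entry per decoder layer, holding zero-initialized key/value tensors
    # and a scalar cache index
    past_key_values = model.init_cache(batch_size, max_length)
    return past_key_values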
|
||||
|
||||
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
attention_mask: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
train: bool = False,
|
||||
params: dict = None,
|
||||
past_key_values: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is None:
|
||||
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
||||
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
||||
|
||||
# prepare encoder inputs
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
if position_ids is None:
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
||||
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
# If past_key_values are passed, then the cache is already initialized and a private flag init_cache has to
# be passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
# so that it can be changed by the FlaxXGLMAttention module.
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
input_ids=jnp.array(input_ids, dtype="i4"),
|
||||
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=not train,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XGLM_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXGLMModel(FlaxXGLMPreTrainedModel):
|
||||
module_class = FlaxXGLMModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxXGLMModel,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
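# A minimal usage sketch for the bare Flax model registered above. The checkpoint name is an
# assumption for illustration; the canonical example is the one injected by
# `append_call_sample_docstring`.
def _sketch_flax_xglm_model_usage():
    from transformers import XGLMTokenizer, FlaxXGLMModel

    tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
    model = FlaxXGLMModel.from_pretrained("facebook/xglm-564M")

    inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    outputs = model(**inputs)
    return outputs.last_hidden_state  # (batch_size, sequence_length, config.d_model)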
|
||||
|
||||
|
||||
class FlaxXGLMForCausalLMModule(nn.Module):
|
||||
config: XGLMConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.model = FlaxXGLMModule(self.config, self.dtype)
|
||||
self.lm_head = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
deterministic=deterministic,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.model.variables["params"]["embed_tokens"]["embedding"]
|
||||
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=lm_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
""",
|
||||
XGLM_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
|
||||
module_class = FlaxXGLMForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since XGLM uses a causal mask, those positions are masked anyway.
# Thus we can create a single static attention_mask here, which is more efficient for compilation.
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxXGLMForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
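# A short generation sketch for the Flax causal-LM head above; `generate` drives the
# `prepare_inputs_for_generation` / `update_inputs_for_generation` hooks defined in the class.
# The checkpoint name is an assumption for illustration.
def _sketch_flax_xglm_generation():
    from transformers import XGLMTokenizer, FlaxXGLMForCausalLM

    tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
    model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M")

    input_ids = tokenizer("I wanted to conserve energy.", return_tensors="np").input_ids
    output_ids = model.generate(input_ids, max_length=32).sequences
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)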
|
src/transformers/models/xglm/modeling_xglm.py (new executable file, 943 lines)
@@ -0,0 +1,943 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch XGLM model."""
|
||||
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_xglm import XGLMConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
|
||||
_CONFIG_FOR_DOC = "XGLMConfig"
|
||||
_TOKENIZER_FOR_DOC = "XGLMTokenizer"
|
||||
|
||||
|
||||
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/xglm-564M",
|
||||
# See all XGLM models at https://huggingface.co/models?filter=xglm
|
||||
]
|
||||
|
||||
XGLM_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`XGLMConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
XGLM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
|
||||
`past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
||||
`past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
|
||||
associated vectors than the model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
||||
"""
|
||||
Make causal mask used for uni-directional (decoder) self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
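# A tiny worked example (sketch, not used in this module) of the helper above: with `tgt_len=3` and
# `past_key_values_length=2`, each query may attend to all cached positions plus itself, never to the future.
def _sketch_causal_mask():
    mask = _make_causal_mask(torch.Size([1, 3]), torch.float32, past_key_values_length=2)
    # mask.shape == (1, 1, 3, 5); e.g. mask[0, 0, 0] == [0, 0, 0, -inf, -inf] and mask[0, 0, 2] == [0, 0, 0, 0, 0]
    return mask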
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
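# Sketch of the padding-mask expansion above: a row `[1, 1, 0]` (last token is padding) is turned into an
# additive mask that is 0.0 for visible keys and `torch.finfo(dtype).min` for the padded key.
def _sketch_expand_mask():
    mask = torch.tensor([[1, 1, 0]])
    expanded = _expand_mask(mask, torch.float32)  # shape (1, 1, 3, 3)
    # expanded[0, 0, :, 2] equals torch.finfo(torch.float32).min for every query position
    return expanded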
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||
return incremental_indices.long() + padding_idx
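# Worked example (sketch) for the helper above with `padding_idx=1`: non-padding tokens are numbered
# from `padding_idx + 1`, padding positions keep `padding_idx`.
def _sketch_position_ids():
    input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # 1 is the padding index
    return create_position_ids_from_input_ids(input_ids, padding_idx=1)  # -> tensor([[2, 3, 4, 1, 1]])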
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding with M2M100->XGLM
|
||||
class XGLMSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.offset = 2
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
||||
|
||||
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
||||
if hasattr(self, "weights"):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
"""
|
||||
Build sinusoidal embeddings.
|
||||
|
||||
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
|
||||
"Attention Is All You Need".
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
||||
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
||||
if embedding_dim % 2 == 1:
|
||||
# zero pad
|
||||
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
||||
if padding_idx is not None:
|
||||
emb[padding_idx, :] = 0
|
||||
|
||||
return emb
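# In formula form (sketch): with half = embedding_dim // 2 and inv_freq_i = exp(-i * ln(10000) / (half - 1)),
# the row for position `pos` is [sin(pos * inv_freq_0), ..., sin(pos * inv_freq_{half-1}),
# cos(pos * inv_freq_0), ..., cos(pos * inv_freq_{half-1})], i.e. all sines first and then all cosines,
# matching the tensor2tensor layout mentioned in the docstring.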
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
|
||||
):
|
||||
if input_ids is not None:
|
||||
bsz, seq_len = input_ids.size()
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
|
||||
input_ids.device
|
||||
)
|
||||
else:
|
||||
bsz, seq_len = inputs_embeds.size()[:-1]
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
# expand embeddings if needed
|
||||
max_pos = self.padding_idx + 1 + seq_len
|
||||
if max_pos > self.weights.size(0):
|
||||
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
|
||||
|
||||
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
|
||||
|
||||
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
"""
|
||||
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
||||
|
||||
Args:
|
||||
inputs_embeds: torch.Tensor
|
||||
|
||||
Returns: torch.Tensor
|
||||
"""
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
sequence_length = input_shape[1]
|
||||
|
||||
position_ids = torch.arange(
|
||||
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
||||
)
|
||||
return position_ids.unsqueeze(0).expand(input_shape).contiguous()
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
|
||||
class XGLMAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
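# Shape sketch for the forward pass above (self-attention, no cache): hidden_states (bsz, tgt_len, embed_dim)
# -> per-head projections of shape (bsz * num_heads, -1, head_dim) -> attn_weights (bsz * num_heads, tgt_len, src_len)
# -> attn_output merged back to (bsz, tgt_len, embed_dim) before `out_proj`.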
|
||||
|
||||
|
||||
class XGLMDecoderLayer(nn.Module):
|
||||
def __init__(self, config: XGLMConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = XGLMAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
if config.add_cross_attention:
|
||||
self.crossattention = XGLMAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
|
||||
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
|
||||
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
*(encoder_attention_heads,)*.
|
||||
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size *(decoder_attention_heads,)*.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
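# Output layout sketch for the layer above: outputs = (hidden_states,), extended with
# (self_attn_weights, cross_attn_weights) when `output_attentions=True` and with (present_key_value,)
# when `use_cache=True`.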
|
||||
|
||||
|
||||
class XGLMPreTrainedModel(PreTrainedModel):
|
||||
config_class = XGLMConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, XGLMModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XGLM_START_DOCSTRING,
|
||||
)
|
||||
class XGLMModel(XGLMPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_layers* layers. Each layer is an [`XGLMDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: XGLMConfig
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
|
||||
self.embed_positions = XGLMSinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
config.pad_token_id,
|
||||
)
|
||||
self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
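# Sketch: for sequences longer than one token this combines the causal mask from `_make_causal_mask`
# with the expanded padding mask from `_expand_mask`; both are additive, so a key position is visible
# only if it is neither in the future nor padding.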
|
||||
|
||||
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using [`~XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
of the decoder.
|
||||
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
||||
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
||||
selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, use_cache)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
""",
|
||||
XGLM_START_DOCSTRING,
|
||||
)
|
||||
class XGLMForCausalLM(XGLMPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"model.embed_positions.weights",
|
||||
r"lm_head.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
r"model.embed_positions.weights",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = XGLMModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=CausalLMOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# shift labels and add a pad token to the end
|
||||
shift_labels = labels.new_zeros(labels.shape)
|
||||
shift_labels[:, :-1] = labels[:, 1:].clone()
|
||||
shift_labels[:, -1] = self.config.pad_token_id
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
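# Worked sketch of the shift above: labels [a, b, c, d] become shift_labels [b, c, d, pad_token_id],
# so position i is trained to predict token i + 1 and the last position is trained against the pad token.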
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
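# An end-to-end sketch for the PyTorch causal-LM class above (the checkpoint name is an assumption
# and the helper is not exercised anywhere in this module):
def _sketch_xglm_generation():
    from transformers import XGLMTokenizer, XGLMForCausalLM

    tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
    model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

    input_ids = tokenizer("I wanted to conserve energy.", return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_length=32, num_beams=2)
    # beam search exercises `_reorder_cache` above so the key/value cache stays aligned with the beams
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)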
|
src/transformers/models/xglm/tokenization_xglm.py (new file, 313 lines)
@@ -0,0 +1,313 @@
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
"""Tokenization classes for XGLM."""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/sentencepiece.bpe.model",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"facebook/xglm-564M": 1024,
|
||||
}
|
||||
|
||||
|
||||
class XGLMTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
|
||||
[SentencePiece](https://github.com/google/sentencepiece).
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
||||
sequence. The token used is the `cls_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
||||
The token used is the `sep_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<madeupword0>", ..., "<madeupword6>"]`):
Additional special tokens used by the tokenizer. By default these are the seven `<madeupword{i}>` placeholder
tokens that the original fairseq checkpoint reserves at the end of the vocabulary (see `__init__` below).
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
to set:
|
||||
|
||||
- `enable_sampling`: Enable subword regularization.
|
||||
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
||||
|
||||
- `nbest_size = {0,1}`: No sampling is performed.
|
||||
- `nbest_size > 1`: samples from the nbest_size results.
|
||||
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
using the forward-filtering-and-backward-sampling algorithm.
|
||||
|
||||
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
||||
BPE-dropout.
|
||||
|
||||
Attributes:
|
||||
sp_model (`SentencePieceProcessor`):
|
||||
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
sep_token="</s>",
|
||||
cls_token="<s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
|
||||
# Compatibility with the original tokenizer
|
||||
self.num_madeup_words = 7
|
||||
madeup_words = [f"<madeupword{i}>" for i in range(self.num_madeup_words)]
|
||||
|
||||
kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
|
||||
kwargs["additional_special_tokens"] += [
|
||||
word for word in madeup_words if word not in kwargs["additional_special_tokens"]
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
cls_token=cls_token,
|
||||
pad_token=pad_token,
|
||||
sp_model_kwargs=self.sp_model_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(str(vocab_file))
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
# Original fairseq vocab and spm vocab must be "aligned":
|
||||
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
||||
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
|
||||
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
|
||||
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
|
||||
|
||||
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
||||
self.fairseq_offset = 1
|
||||
|
||||
# Mimic fairseq token-to-id alignment for the first 4 tokens
|
||||
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
|
||||
|
||||
sp_size = len(self.sp_model)
|
||||
madeup_words = {f"<madeupword{i}>": sp_size + i + self.fairseq_offset for i in range(self.num_madeup_words)}
|
||||
self.fairseq_tokens_to_ids.update(madeup_words)
|
||||
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
|
||||
# for backward compatibility
|
||||
if not hasattr(self, "sp_model_kwargs"):
|
||||
self.sp_model_kwargs = {}
|
||||
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An XGLM sequence has the following format:

- single sequence: `</s> X`
- pair of sequences: `</s> A </s></s> B`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [self.sep_token_id] + token_ids_0
|
||||
sep = [self.sep_token_id]
|
||||
return sep + token_ids_0 + sep + sep + token_ids_1
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0))
|
||||
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1))
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. XGLM does
not make use of token type ids, therefore a list of zeros is returned.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of zeros.
|
||||
|
||||
"""
|
||||
|
||||
sep = [self.sep_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(sep + token_ids_0) * [0]
|
||||
return len(sep + token_ids_0 + sep + sep + token_ids_1) * [0]
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.sp_model) + self.fairseq_offset + self.num_madeup_words
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
if token in self.fairseq_tokens_to_ids:
|
||||
return self.fairseq_tokens_to_ids[token]
|
||||
spm_id = self.sp_model.PieceToId(token)
|
||||
|
||||
# Need to return unknown token if the SP model returned 0
|
||||
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if index in self.fairseq_ids_to_tokens:
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file,)
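The sketch below is illustrative only (not part of the diff) and shows how the fairseq/SentencePiece alignment described in `__init__` above surfaces in practice. It assumes the `facebook/xglm-564M` checkpoint; the exact ids depend on that SentencePiece model.

```python
# Illustrative check of the id alignment documented above.
from transformers import XGLMTokenizer

tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M")

# </s> (the sep token) is prepended by build_inputs_with_special_tokens.
ids = tok("Hello world")["input_ids"]
assert ids[0] == tok.sep_token_id

# The vocabulary is the SentencePiece vocab shifted by fairseq_offset (1),
# plus the seven <madeupword*> tokens appended at the end.
assert tok.vocab_size == len(tok.sp_model) + tok.fairseq_offset + tok.num_madeup_words
```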
206  src/transformers/models/xglm/tokenization_xglm_fast.py  Normal file
@ -0,0 +1,206 @@
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for XGLM."""
|
||||
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .tokenization_xglm import XGLMTokenizer
|
||||
else:
|
||||
XGLMTokenizer = None
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/sentencepiece.bpe.model",
|
||||
},
|
||||
"tokenizer_file": {
|
||||
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/tokenizer.json",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"facebook/xglm-564M": 1024,
|
||||
}
|
||||
|
||||
|
||||
class XGLMTokenizerFast(PreTrainedTokenizerFast):
|
||||
"""
|
||||
Construct a "fast" XGLM tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from [`RobertaTokenizer`]
|
||||
and [`XLNetTokenizer`]. Based on
|
||||
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
||||
refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
||||
sequence. The token used is the `cls_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
||||
The token used is the `sep_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<madeupword0>", ..., "<madeupword6>"]`):
|
||||
Additional special tokens used by the tokenizer.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
slow_tokenizer_class = XGLMTokenizer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
sep_token="</s>",
|
||||
cls_token="<s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
**kwargs
|
||||
):
|
||||
# Compatibility with the original tokenizer
|
||||
self.num_madeup_words = 7
|
||||
madeup_words = [f"<madeupword{i}>" for i in range(self.num_madeup_words)]
|
||||
|
||||
kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
|
||||
kwargs["additional_special_tokens"] += [
|
||||
word for word in madeup_words if word not in kwargs["additional_special_tokens"]
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
vocab_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
sep_token=sep_token,
|
||||
cls_token=cls_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = bool(self.vocab_file)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An XGLM sequence has the following format:

- single sequence: `</s> X`
- pair of sequences: `</s> A </s></s> B`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [self.sep_token_id] + token_ids_0
|
||||
sep = [self.sep_token_id]
|
||||
return sep + token_ids_0 + sep + sep + token_ids_1
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. XGLM does
not make use of token type ids, therefore a list of zeros is returned.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of zeros.
|
||||
|
||||
"""
|
||||
|
||||
sep = [self.sep_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(sep + token_ids_0) * [0]
|
||||
return len(sep + token_ids_0 + sep + sep + token_ids_1) * [0]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
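A quick, hedged parity check between the slow and fast tokenizers added above (illustrative only, not part of the diff; assumes the `facebook/xglm-564M` checkpoint and an arbitrary sample sentence):

```python
from transformers import XGLMTokenizer, XGLMTokenizerFast

slow = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
fast = XGLMTokenizerFast.from_pretrained("facebook/xglm-564M")

text = "Few-shot learning with multilingual language models"
assert slow(text)["input_ids"] == fast(text)["input_ids"]
assert slow.decode(slow(text)["input_ids"]) == fast.decode(fast(text)["input_ids"])
```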
|
@ -961,3 +961,24 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXGLMForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXGLMModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXGLMPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
@ -3848,6 +3848,30 @@ class WavLMPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class XGLMForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class XGLMModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class XGLMPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
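The dummy classes above only matter when a backend is missing. A hedged illustration (not part of the diff) of what they do in an environment without PyTorch installed:

```python
# Illustrative only: in an environment without torch, XGLMModel resolves to the
# DummyObject version above and fails loudly at construction time.
from transformers import XGLMModel

try:
    XGLMModel()  # raises a clear error instead of an obscure import failure
except ImportError as err:
    print(err)  # message explains that the "torch" backend is required
```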
|
||||
|
||||
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -136,6 +136,13 @@ class T5Tokenizer(metaclass=DummyObject):
|
||||
requires_backends(self, ["sentencepiece"])
|
||||
|
||||
|
||||
class XGLMTokenizer(metaclass=DummyObject):
|
||||
_backends = ["sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["sentencepiece"])
|
||||
|
||||
|
||||
class XLMProphetNetTokenizer(metaclass=DummyObject):
|
||||
_backends = ["sentencepiece"]
|
||||
|
||||
|
@ -297,6 +297,13 @@ class T5TokenizerFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class XGLMTokenizerFast(metaclass=DummyObject):
|
||||
_backends = ["tokenizers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class XLMRobertaTokenizerFast(metaclass=DummyObject):
|
||||
_backends = ["tokenizers"]
347  tests/test_modeling_flax_xglm.py  Normal file
@ -0,0 +1,347 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
|
||||
|
||||
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxXGLMModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
d_model=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
ffn_dim=37,
|
||||
activation_function="gelu",
|
||||
activation_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = d_model
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.activation_function = activation_function
|
||||
self.activation_dropout = activation_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = None
|
||||
self.bos_token_id = 0
|
||||
self.eos_token_id = 2
|
||||
self.pad_token_id = 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length], self.vocab_size), 3, self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = XGLMConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
num_layers=self.num_hidden_layers,
|
||||
attention_heads=self.num_attention_heads,
|
||||
ffn_dim=self.ffn_dim,
|
||||
activation_function=self.activation_function,
|
||||
activation_dropout=self.activation_dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
use_cache=True,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids, input_mask)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config, input_ids, attention_mask = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
|
||||
max_decoder_length = 20
|
||||
model = model_class_name(config)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
|
||||
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids)
|
||||
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
|
||||
max_decoder_length = 20
|
||||
model = model_class_name(config)
|
||||
|
||||
attention_mask_cache = jnp.concatenate(
|
||||
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask_cache,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:],
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
attention_mask=attention_mask_cache,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids, attention_mask=attention_mask)
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
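The checks above exercise the `init_cache` / `past_key_values` contract. A hedged user-level sketch of the same pattern (not part of the diff; assumes the `facebook/xglm-564M` checkpoint, an illustrative prompt, and an arbitrary cache budget of 8 extra positions):

```python
import jax.numpy as jnp
from transformers import FlaxXGLMForCausalLM, XGLMTokenizer

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M")

input_ids = tokenizer("Today is a nice day and", return_tensors="np")["input_ids"]
max_length = input_ids.shape[-1] + 8

# Pre-allocate the cache, feed the prompt once, then reuse past_key_values
# for subsequent single-token steps.
past = model.init_cache(input_ids.shape[0], max_length)
attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4")
position_ids = jnp.arange(input_ids.shape[-1])[None, :]
outputs = model(
    input_ids,
    attention_mask=attention_mask,
    past_key_values=past,
    position_ids=position_ids,
)
print(outputs.logits.shape)
```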
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_flax
|
||||
class FlaxXGLMModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxXGLMModel, FlaxXGLMForCausalLM) if is_flax_available() else ()
|
||||
all_generative_model_classes = (FlaxXGLMForCausalLM,) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxXGLMModelTester(self)
|
||||
|
||||
def test_use_cache_forward(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
|
||||
|
||||
def test_use_cache_forward_with_attn_mask(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_forward_with_attn_mask(
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = XGLMTokenizer.from_pretrained("XGLM", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||
|
||||
model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||
model.config.num_beams = 1
|
||||
model.config.do_sample = False
|
||||
|
||||
jit_generate = jax.jit(model.generate)
|
||||
|
||||
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
|
||||
|
||||
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
||||
|
||||
expected_string = [
|
||||
"Hello this is a long string of questions, but I'm not sure if I'm",
|
||||
"Hey, I'm a newbie to the forum and I'",
|
||||
]
|
||||
|
||||
self.assertListEqual(output_string, expected_string)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slightly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slightly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("facebook/xglm-564M")
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
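The cross-framework tests above boil down to a save/load round trip. A hedged sketch of that pattern as user-level code (not part of the diff; assumes both torch and flax are installed and the `facebook/xglm-564M` checkpoint is available):

```python
import tempfile
from transformers import FlaxXGLMForCausalLM, XGLMForCausalLM

pt_model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
with tempfile.TemporaryDirectory() as tmp:
    pt_model.save_pretrained(tmp)
    # Load the PyTorch weights directly into the Flax model class.
    fx_model = FlaxXGLMForCausalLM.from_pretrained(tmp, from_pt=True)
```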
471  tests/test_modeling_xglm.py  Normal file
@ -0,0 +1,471 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import XGLMConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
||||
|
||||
|
||||
class XGLMModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
d_model=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
ffn_dim=37,
|
||||
activation_function="gelu",
|
||||
activation_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = d_model
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.activation_function = activation_function
|
||||
self.activation_dropout = activation_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = None
|
||||
self.bos_token_id = 0
|
||||
self.eos_token_id = 2
|
||||
self.pad_token_id = 1
|
||||
|
||||
def get_large_model_config(self):
|
||||
return XGLMConfig.from_pretrained("facebook/xglm-564M")
|
||||
|
||||
def prepare_config_and_inputs(
|
||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||
):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
)
|
||||
|
||||
def get_config(
|
||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||
):
|
||||
return XGLMConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
num_layers=self.num_hidden_layers,
|
||||
attention_heads=self.num_attention_heads,
|
||||
ffn_dim=self.ffn_dim,
|
||||
activation_function=self.activation_function,
|
||||
activation_dropout=self.activation_dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
use_cache=True,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_xglm_model(self, config, input_ids, input_mask, head_mask, *args):
|
||||
model = XGLMModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, head_mask=head_mask)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result.past_key_values), config.num_hidden_layers)
|
||||
|
||||
def create_and_check_xglm_model_past(self, config, input_ids, input_mask, head_mask, *args):
|
||||
model = XGLMModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extend to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# append to next input_ids and token_type_ids
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_xglm_model_attention_mask_past(self, config, input_ids, input_mask, head_mask, *args):
|
||||
model = XGLMModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# create attention mask
|
||||
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
half_seq_length = self.seq_length // 2
|
||||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
|
||||
|
||||
# create hypothetical next token and extend to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.zeros((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_xglm_model_past_large_inputs(self, config, input_ids, input_mask, head_mask, *args):
|
||||
model = XGLMModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
|
||||
|
||||
output, past = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extend to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=1)
|
||||
|
||||
# append to next input_ids
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, *args):
|
||||
model = XGLMForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = XGLMForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
if gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
result = model(input_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_xglm_weight_initialization(self, config, *args):
|
||||
model = XGLMModel(config)
|
||||
model_std = model.config.initializer_range / math.sqrt(2 * model.config.num_hidden_layers)
|
||||
for key in model.state_dict().keys():
|
||||
if "c_proj" in key and "weight" in key:
|
||||
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
|
||||
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
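The threshold above compares the empirical standard deviation of the projection weights with `initializer_range / sqrt(2 * num_layers)`. For a sense of scale (values assumed for illustration, not taken from this diff):

```python
import math
initializer_range, num_layers = 0.02, 24
print(initializer_range / math.sqrt(2 * num_layers))  # ~0.00289
```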
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
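The tester methods above verify the `past_key_values` contract: once the cache is populated, only the new token needs to be fed. A hedged user-level sketch (not part of the diff; assumes the `facebook/xglm-564M` checkpoint, with an arbitrary token id for illustration):

```python
import torch
from transformers import XGLMModel, XGLMTokenizer

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMModel.from_pretrained("facebook/xglm-564M").eval()

input_ids = tokenizer("Today is a nice day and", return_tensors="pt")["input_ids"]
with torch.no_grad():
    out = model(input_ids, use_cache=True)
    next_token = torch.tensor([[5]])  # arbitrary id, purely for illustration
    # Feeding only the new token together with the cache continues from where
    # the prompt left off.
    cached = model(next_token, past_key_values=out.past_key_values)
print(cached.last_hidden_state.shape)  # (1, 1, hidden_size)
```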
|
||||
|
||||
|
||||
@require_torch
|
||||
class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = XGLMModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=XGLMConfig, n_embd=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_xglm_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_model(*config_and_inputs)
|
||||
|
||||
def test_xglm_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_model_past(*config_and_inputs)
|
||||
|
||||
def test_xglm_model_att_mask_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_model_attention_mask_past(*config_and_inputs)
|
||||
|
||||
def test_xglm_model_past_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_xglm_lm_head_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_xglm_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
def test_xglm_weight_initialization(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||
model.to(torch_device)
|
||||
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# use sentences of different lengths to test batching
|
||||
sentences = [
|
||||
"Hello, my dog is a little",
|
||||
"Today, I",
|
||||
]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
)
|
||||
|
||||
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||
|
||||
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
|
||||
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||
|
||||
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||
|
||||
expected_output_sentence = [
|
||||
"Hello, my dog is a little bit of a shy one, but he is very friendly",
|
||||
"Today, I am going to share with you a few of my favorite things",
|
||||
]
|
||||
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in XGLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = XGLMModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
def _test_lm_generate_xglm_helper(
|
||||
self,
|
||||
gradient_checkpointing=False,
|
||||
verify_outputs=True,
|
||||
):
|
||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||
if gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
else:
|
||||
model.gradient_checkpointing_disable()
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[2, 268, 9865]], dtype=torch.long, device=torch_device) # The dog
|
||||
# </s> The dog is a very friendly dog. He is very affectionate and loves to play with other
|
||||
# fmt: off
|
||||
expected_output_ids = [2, 268, 9865, 67, 11, 1988, 57252, 9865, 5, 984, 67, 1988, 213838, 1658, 53, 70446, 33, 6657, 278, 1581]
|
||||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False, num_beams=1)
|
||||
if verify_outputs:
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_xglm(self):
|
||||
self._test_lm_generate_xglm_helper()
|
||||
|
||||
@slow
|
||||
def test_lm_generate_xglm_with_gradient_checkpointing(self):
|
||||
self._test_lm_generate_xglm_helper(gradient_checkpointing=True)
|
||||
|
||||
@slow
|
||||
def test_xglm_sample(self):
|
||||
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
|
||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
|
||||
input_ids = tokenized.input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my"
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||
|
||||
    @slow
    def test_xglm_sample_max_time(self):
        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.15

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=1.25 * MAX_TIME))
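# `max_time` bounds the wall-clock time spent inside `generate` (in seconds) rather than
# the number of generated tokens, which is what the timing assertions above rely on.
# A minimal sketch, assuming the facebook/xglm-564M checkpoint:
from transformers import XGLMForCausalLM, XGLMTokenizer

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids
# Generation stops after roughly 0.5 s even if max_length has not been reached.
output = model.generate(input_ids, do_sample=False, max_time=0.5, max_length=256)
print(tokenizer.decode(output[0], skip_special_tokens=True))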
203 tests/test_tokenization_xglm.py Normal file
@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import shutil
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, XGLMTokenizer, XGLMTokenizerFast
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow

from .test_tokenization_common import TokenizerTesterMixin


SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


@require_sentencepiece
@require_tokenizers
class XGLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = XGLMTokenizer
    rust_tokenizer_class = XGLMTokenizerFast
    test_rust_tokenizer = True
    test_sentencepiece = True

    def setUp(self):
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = XGLMTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "<pad>"
        token_id = 1

        self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
        self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

    def test_get_vocab(self):
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "<s>")
        self.assertEqual(vocab_keys[1], "<pad>")
        self.assertEqual(len(vocab_keys), 1_008)

    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 1_008)

    def test_full_tokenizer(self):
        tokenizer = XGLMTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [
                value + tokenizer.fairseq_offset
                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
            ],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

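# The tokenizer stores an id offset (`fairseq_offset`) that is applied to raw
# SentencePiece ids so they line up with the model vocabulary, in which the special
# tokens <s>, <pad>, </s> and <unk> come first; that is why the expected ids above are
# written as `value + tokenizer.fairseq_offset`. A minimal sketch of inspecting that
# mapping, reusing the SAMPLE_VOCAB fixture defined earlier in this file:
from transformers import XGLMTokenizer

tokenizer = XGLMTokenizer(SAMPLE_VOCAB, keep_accents=True)  # SAMPLE_VOCAB: fixture path defined above
tokens = tokenizer.tokenize("This is a test")
print(tokenizer.fairseq_offset)  # offset applied to raw SentencePiece ids
print(list(zip(tokens, tokenizer.convert_tokens_to_ids(tokens))))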
    @cached_property
    def big_tokenizer(self):
        return XGLMTokenizer.from_pretrained("facebook/xglm-564M")

    def test_picklable_without_disk(self):
        with tempfile.NamedTemporaryFile() as f:
            shutil.copyfile(SAMPLE_VOCAB, f.name)
            tokenizer = XGLMTokenizer(f.name, keep_accents=True)
            pickled_tokenizer = pickle.dumps(tokenizer)
        pickle.loads(pickled_tokenizer)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "I was born in 92000, and this is falsé."

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

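# XGLMTokenizer (SentencePiece-based, "slow") and XGLMTokenizerFast (tokenizers-based,
# "fast") are expected to agree token for token, which is what the test above asserts on
# the small fixture. A minimal sketch of the same check against the published checkpoint,
# assuming it is reachable on the Hub:
from transformers import XGLMTokenizer, XGLMTokenizerFast

slow_tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
fast_tok = XGLMTokenizerFast.from_pretrained("facebook/xglm-564M")

text = "I was born in 92000, and this is falsé."
assert slow_tok.encode(text) == fast_tok.encode(text)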
    @slow
    def test_tokenization_base_easy_symbols(self):
        symbols = "Hello World!"
        original_tokenizer_encodings = [2, 31227, 4447, 35]

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    def test_tokenization_base_hard_symbols(self):
        symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth'
        # fmt: off
        original_tokenizer_encodings = [2, 1018, 67, 11, 1988, 2617, 5631, 278, 11, 3407, 48, 71630, 28085, 4, 3234, 157, 13, 6, 5, 6, 4, 3526, 768, 15, 659, 57, 298, 3983, 864, 129, 21, 6, 5, 13675, 377, 652, 7580, 10341, 155, 2817, 422, 1666, 7, 1674, 53, 113, 202277, 17892, 33, 60, 87, 4, 3234, 157, 61, 2667, 52376, 19, 88, 23, 735]
        # fmt: on

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    def test_tokenizer_integration(self):
        # fmt: off
        expected_encoding = {
            'input_ids': [[2, 108825, 1163, 15, 88010, 473, 15898, 157, 13672, 1857, 312, 8, 238021, 1163, 53, 13672, 1857, 312, 8, 53283, 182396, 8, 18566, 16, 36733, 4101, 8, 230, 244017, 122553, 7, 15, 132597, 4, 293, 12511, 7610, 4, 3414, 132597, 9, 4, 32361, 362, 4, 734, 28512, 32569, 18, 4, 32361, 26096, 14982, 73, 18715, 21433, 235261, 15, 492, 12427, 16, 53, 18715, 21433, 65454, 15, 23659, 563, 16, 278, 597, 2843, 595, 7931, 182396, 64186, 22, 886, 595, 132981, 53, 25540, 3449, 43982, 39901, 5951, 878, 330, 4, 27694, 80269, 312, 53, 6517, 11780, 611, 20408, 5], [2, 6, 132597, 67, 42897, 33, 592, 8, 163729, 25540, 361, 136997, 109514, 173230, 7, 501, 60, 102913, 196, 5631, 235, 63243, 473, 6, 231757, 74, 5277, 7905, 53, 3095, 37317, 22, 454, 183874, 5], [2, 268, 31298, 46530, 6, 132935, 43831, 7, 597, 32, 24, 3688, 9865, 5]],
            'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
        }  # noqa: E501
        # fmt: on

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="facebook/xglm-564M",
            padding=False,
        )
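# The integration test pins the exact `input_ids` and `attention_mask` the tokenizer must
# produce for a fixed batch of sentences (handled inside `tokenizer_integration_test_util`).
# A minimal sketch of producing a padded batch encoding yourself; the example sentences
# are placeholders, not the ones behind the pinned values above:
from transformers import XGLMTokenizerFast

tokenizer = XGLMTokenizerFast.from_pretrained("facebook/xglm-564M")
batch = tokenizer(
    ["Transformers supports many languages.", "XGLM est un modèle multilingue."],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["attention_mask"].shape)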
@ -112,6 +112,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "ViltForImagesAndTextClassification",
    "ViltForImageAndTextRetrieval",
    "ViltForMaskedLM",
    "XGLMEncoder",
    "XGLMDecoder",
    "XGLMDecoderWrapper",
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "SegformerDecodeHead",