Add XGLM models (#14876)

* add xglm

* update vocab size

* fix model name

* style and tokenizer

* typo

* no mask token

* fix pos embed compute

* fix args

* fix tokenizer

* fix positions

* fix tokenization

* style and dic fixes

* fix imports

* add fast tokenizer

* update names

* add pt tests

* fix tokenizer

* fix typo

* fix tokenizer import

* fix fast tokenizer

* fix tokenizer

* fix converter

* add tokenizer test

* update checkpoint names

* fix tokenizer tests

* fix slow tests

* add copied from comments

* rst -> mdx

* flax model

* update flax tests

* quality

* style

* doc

* update index and readme

* fix copies

* fix doc

* update toctrr

* fix indent

* minor fixes

* fix config doc

* don't save embed_pos weights

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address Sylvains commnets, few doc fixes

* fix check_repo

* align order of arguments

* fix copies

* fix labels

* remove unnecessary mapping

* fix saving tokenizer

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil 2022-01-28 18:55:23 +01:00 committed by GitHub
parent b6b79faa7e
commit d25e25ee2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 3697 additions and 0 deletions

View File

@ -319,6 +319,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/master/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[XGLM](https://huggingface.co/docs/master/transformers/model_doc/xglm)** (From Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.

View File

@ -297,6 +297,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/master/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[XGLM](https://huggingface.co/docs/master/transformers/model_doc/xglm)** (From Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.

View File

@ -321,6 +321,7 @@ conda install -c huggingface transformers
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/master/transformers/model_doc/wav2vec2_phoneme)** (来自 Facebook AI) 伴随论文 [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) 由 Qiantong Xu, Alexei Baevski, Michael Auli 发布。
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[XGLM](https://huggingface.co/docs/master/transformers/model_doc/xglm)** (From Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (来自 Facebook) 伴随论文 [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) 由 Guillaume Lample and Alexis Conneau 发布。
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (来自 Microsoft Research) 伴随论文 [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) 由 Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou 发布。
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (来自 Facebook AI), 伴随论文 [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) 由 Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov 发布。

View File

@ -333,6 +333,7 @@ conda install -c huggingface transformers
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/master/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[XGLM](https://huggingface.co/docs/master/transformers/model_doc/xglm)** (From Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.

View File

@ -304,6 +304,8 @@
title: Wav2Vec2Phoneme
- local: model_doc/wavlm
title: WavLM
- local: model_doc/xglm
title: XGLM
- local: model_doc/xlm
title: XLM
- local: model_doc/xlm-prophetnet

View File

@ -178,6 +178,7 @@ conversion utilities for the following models.
1. **[WavLM](model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Wav2Vec2](model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/master/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[XGLM](https://huggingface.co/docs/master/transformers/model_doc/xglm)** (From Facebook AI) released with the paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
1. **[XLM](model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](model_doc/xlm-prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[XLM-RoBERTa](model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
@ -278,6 +279,7 @@ Flax), PyTorch, and/or TensorFlow.
| ViTMAE | ❌ | ❌ | ✅ | ❌ | ❌ |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |

View File

@ -0,0 +1,75 @@
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# XGLM
## Overview
The XGLM model was proposed in [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668)
by Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal,
Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo,
Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
The abstract from the paper is the following:
*Large-scale autoregressive language models such as GPT-3 are few-shot learners that can perform a wide range of language
tasks without fine-tuning. While these models are known to be able to jointly represent many different languages,
their training data is dominated by English, potentially limiting their cross-lingual generalization.
In this work, we train multilingual autoregressive language models on a balanced corpus covering a diverse set of languages,
and study their few- and zero-shot learning capabilities in a wide range of tasks. Our largest model with 7.5 billion parameters
sets new state of the art in few-shot learning in more than 20 representative languages, outperforming GPT-3 of comparable size
in multilingual commonsense reasoning (with +7.4% absolute accuracy improvement in 0-shot settings and +9.4% in 4-shot settings)
and natural language inference (+5.4% in each of 0-shot and 4-shot settings). On the FLORES-101 machine translation benchmark,
our model outperforms GPT-3 on 171 out of 182 translation directions with 32 training examples, while surpassing the
official supervised baseline in 45 directions. We present a detailed analysis of where the model succeeds and fails,
showing in particular that it enables cross-lingual in-context learning on some tasks, while there is still room for improvement
on surface form robustness and adaptation to tasks that do not have a natural cloze form. Finally, we evaluate our models
in social value tasks such as hate speech detection in five languages and find it has limitations similar to comparable sized GPT-3 models.*
This model was contributed by [Suraj](https://huggingface.co/valhalla). The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/xglm).
## XGLMConfig
[[autodoc]] XGLMConfig
## XGLMTokenizer
[[autodoc]] XGLMTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary
## XGLMTokenizerFast
[[autodoc]] XGLMTokenizerFast
## XGLMModel
[[autodoc]] XGLMModel
- forward
## XGLMForCausalLM
[[autodoc]] XGLMForCausalLM
- forward
## FlaxXGLMModel
[[autodoc]] FlaxXGLMModel
- __call__
## FlaxXGLMForCausalLM
[[autodoc]] FlaxXGLMForCausalLM
- __call__

View File

@ -328,6 +328,7 @@ _import_structure = {
"WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"WavLMConfig",
],
"models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@ -408,6 +409,7 @@ if is_sentencepiece_available():
_import_structure["models.rembert"].append("RemBertTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
_import_structure["models.t5"].append("T5Tokenizer")
_import_structure["models.xglm"].append("XGLMTokenizer")
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
_import_structure["models.xlnet"].append("XLNetTokenizer")
@ -422,6 +424,7 @@ else:
if is_tokenizers_available():
# Fast tokenizers
_import_structure["models.realm"].append("RealmTokenizerFast")
_import_structure["models.xglm"].append("XGLMTokenizerFast")
_import_structure["models.fnet"].append("FNetTokenizerFast")
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
_import_structure["models.clip"].append("CLIPTokenizerFast")
@ -1461,6 +1464,14 @@ if is_torch_available():
"WavLMPreTrainedModel",
]
)
_import_structure["models.xglm"].extend(
[
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XGLMForCausalLM",
"XGLMModel",
"XGLMPreTrainedModel",
]
)
_import_structure["models.xlm"].extend(
[
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -2203,6 +2214,13 @@ if is_flax_available():
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
_import_structure["models.xglm"].extend(
[
"FlaxXGLMForCausalLM",
"FlaxXGLMModel",
"FlaxXGLMPreTrainedModel",
]
)
else:
from .utils import dummy_flax_objects
@ -2463,6 +2481,7 @@ if TYPE_CHECKING:
from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
from .models.xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
@ -2544,6 +2563,7 @@ if TYPE_CHECKING:
from .models.rembert import RemBertTokenizer
from .models.speech_to_text import Speech2TextTokenizer
from .models.t5 import T5Tokenizer
from .models.xglm import XGLMTokenizer
from .models.xlm_prophetnet import XLMProphetNetTokenizer
from .models.xlm_roberta import XLMRobertaTokenizer
from .models.xlnet import XLNetTokenizer
@ -2591,6 +2611,7 @@ if TYPE_CHECKING:
from .models.splinter import SplinterTokenizerFast
from .models.squeezebert import SqueezeBertTokenizerFast
from .models.t5 import T5TokenizerFast
from .models.xglm import XGLMTokenizerFast
from .models.xlm_roberta import XLMRobertaTokenizerFast
from .models.xlnet import XLNetTokenizerFast
from .tokenization_utils_fast import PreTrainedTokenizerFast
@ -3404,6 +3425,7 @@ if TYPE_CHECKING:
WavLMModel,
WavLMPreTrainedModel,
)
from .models.xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
from .models.xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,
@ -4018,6 +4040,7 @@ if TYPE_CHECKING:
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.

View File

@ -910,6 +910,35 @@ class BlenderbotConverter(Converter):
return tokenizer
class XGLMConverter(SpmConverter):
def vocab(self, proto):
vocab = [
("<s>", 0.0),
("<pad>", 0.0),
("</s>", 0.0),
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
# fmt: off
vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]
# fmt: on
return vocab
def unk_id(self, proto):
unk_id = 3
return unk_id
def post_processor(self):
return processors.TemplateProcessing(
single="</s> $A",
pair="</s> $A </s> </s> $B",
special_tokens=[
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
],
)
SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
@ -953,6 +982,7 @@ SLOW_TO_FAST_CONVERTERS = {
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
}

View File

@ -116,6 +116,7 @@ from . import (
wav2vec2_phoneme,
wav2vec2_with_lm,
wavlm,
xglm,
xlm,
xlm_prophetnet,
xlm_roberta,

View File

@ -36,6 +36,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("vit_mae", "ViTMAEConfig"),
("realm", "RealmConfig"),
("nystromformer", "NystromformerConfig"),
("xglm", "XGLMConfig"),
("imagegpt", "ImageGPTConfig"),
("qdqbert", "QDQBertConfig"),
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
@ -128,6 +129,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -208,6 +210,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("vit_mae", "ViTMAE"),
("realm", "Realm"),
("nystromformer", "Nystromformer"),
("xglm", "XGLM"),
("imagegpt", "ImageGPT"),
("qdqbert", "QDQBert"),
("vision-encoder-decoder", "Vision Encoder decoder"),

View File

@ -33,6 +33,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("vilt", "ViltModel"),
("vit_mae", "ViTMAEModel"),
("nystromformer", "NystromformerModel"),
("xglm", "XGLMModel"),
("imagegpt", "ImageGPTModel"),
("qdqbert", "QDQBertModel"),
("fnet", "FNetModel"),
@ -209,6 +210,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("xglm", "XGLMForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("gptj", "GPTJForCausalLM"),

View File

@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("xglm", "FlaxXGLMModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
@ -121,6 +122,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
]
)

View File

@ -0,0 +1,66 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
"tokenization_xglm": ["XGLMTokenizer"],
}
if is_tokenizers_available():
_import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"]
if is_torch_available():
_import_structure["modeling_xglm"] = [
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XGLMForCausalLM",
"XGLMModel",
"XGLMPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_xglm"] = [
"FlaxXGLMForCausalLM",
"FlaxXGLMModel",
"FlaxXGLMPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
from .tokenization_xglm import XGLMTokenizer
if is_tokenizers_available():
from .tokenization_xglm_fast import XGLMTokenizerFast
if is_torch_available():
from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
if is_flax_available():
from .modeling_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@ -0,0 +1,140 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XGLM model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/config.json",
# See all XGLM models at https://huggingface.co/models?filter=xglm
}
class XGLMConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`XGLMModel`]. It is used to instantiate an XGLM
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the XGLM
[facebook/xglm-564M](https://huggingface.co/facebook/xglm-564M) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256008):
Vocabulary size of the XGLM model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`XGLMModel`] or [`FlaxXGLMModel`].
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
d_model (`int`, *optional*, defaults to 1024):
Dimension of the layers and the pooler layer.
ffn_dim (`int`, *optional*, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
num_layers (`int`, *optional*, defaults to 24):
Number of hidden layers Transformer decoder.
attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, dencoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
layerdrop: (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_embedding (`bool`, *optional*, defaults to `True`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
Example:
```python
>>> from transformers import XGLMModel, XGLMConfig
>>> # Initializing a XGLM facebook/xglm-564M style configuration
>>> configuration = XGLMConfig()
>>> # Initializing a model from the facebook/xglm-564M style configuration
>>> model = XGLMModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "xglm"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_attention_heads": "attention_heads",
"hidden_size": "d_model",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
vocab_size=256008,
max_position_embeddings=2048,
d_model=1024,
ffn_dim=4096,
num_layers=24,
attention_heads=16,
activation_function="gelu",
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.0,
layerdrop=0.0,
init_std=0.02,
scale_embedding=True,
use_cache=True,
decoder_start_token_id=2,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.ffn_dim = ffn_dim
self.num_layers = num_layers
self.attention_heads = attention_heads
self.activation_function = activation_function
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.layerdrop = layerdrop
self.init_std = init_std
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)

View File

@ -0,0 +1,802 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax XGLM model."""
import math
import random
from functools import partial
from typing import Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_xglm import XGLMConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
_CONFIG_FOR_DOC = "XGLMConfig"
_TOKENIZER_FOR_DOC = "XGLMTokenizer"
XGLM_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`XGLMConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
XGLM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`~XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
def create_sinusoidal_positions(n_pos, dim, padding_idx=1):
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = np.exp(np.arange(half_dim) * -emb)
emb = np.expand_dims(np.arange(n_pos), 1) * np.expand_dims(emb, 0)
emb = np.concatenate([np.sin(emb), np.cos(emb)], 1)
emb = np.reshape(emb, (n_pos, dim))
if padding_idx is not None:
emb[padding_idx, :] = 0
return jnp.array(emb)
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
class FlaxXGLMAttention(nn.Module):
config: XGLMConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
causal: bool = False
bias: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
"and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend
# to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states)
value_states = self.v_proj(key_value_states)
else:
# self_attention
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class FlaxXGLMDecoderLayer(nn.Module):
config: XGLMConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxXGLMAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
if self.config.add_cross_attention:
self.encoder_attn = FlaxXGLMAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer.__call__
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class FlaxXGLMDecoderLayerCollection(nn.Module):
config: XGLMConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxXGLMDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_layers)
]
self.layerdrop = self.config.layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_self_attns, all_cross_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class FlaxXGLMModule(nn.Module):
config: XGLMConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_target_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
self.embed_positions = create_sinusoidal_positions(
self.config.max_position_embeddings + self.offset, embed_dim
)
self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# embed positions
position_ids = position_ids + self.offset
positions = jnp.take(self.embed_positions, position_ids, axis=0)
hidden_states = inputs_embeds + positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
if not return_dict:
outputs = (last_hidden_states,) + outputs[1:]
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
config_class = XGLMConfig
base_model_prefix: str = "model"
module_class: nn.Module = None
def __init__(
self,
config: XGLMConfig,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
return module_init_outputs["params"]
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
past_key_values: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if encoder_hidden_states is not None and encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
# changed by FlaxXGLMAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
@add_start_docstrings(
"The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
XGLM_START_DOCSTRING,
)
class FlaxXGLMModel(FlaxXGLMPreTrainedModel):
module_class = FlaxXGLMModule
append_call_sample_docstring(
FlaxXGLMModel,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutputWithPastAndCrossAttentions,
_CONFIG_FOR_DOC,
)
class FlaxXGLMForCausalLMModule(nn.Module):
config: XGLMConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.model = FlaxXGLMModule(self.config, self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["embed_tokens"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
XGLM_START_DOCSTRING,
)
class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
module_class = FlaxXGLMForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since GPT2 uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
FlaxXGLMForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)

View File

@ -0,0 +1,943 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XGLM model."""
import math
import random
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_xglm import XGLMConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
_CONFIG_FOR_DOC = "XGLMConfig"
_TOKENIZER_FOR_DOC = "XGLMTokenizer"
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/xglm-564M",
# See all XGLM models at https://huggingface.co/models?filter=xglm
]
XGLM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`XGLMConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
XGLM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
``input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
can choose to directly pass an embedded representation. This is useful if you want more control over how to
convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
`past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
associated vectors than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding with M2M100->XGLM
class XGLMSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__()
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
"""
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
"Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
@torch.no_grad()
def forward(
self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
):
if input_ids is not None:
bsz, seq_len = input_ids.size()
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
input_ids.device
)
else:
bsz, seq_len = inputs_embeds.size()[:-1]
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape).contiguous()
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
class XGLMAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class XGLMDecoderLayer(nn.Module):
def __init__(self, config: XGLMConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = XGLMAttention(
embed_dim=self.embed_dim,
num_heads=config.attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
if config.add_cross_attention:
self.crossattention = XGLMAttention(
embed_dim=self.embed_dim,
num_heads=config.attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
attention_mask (`torch.FloatTensor`): attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape *(seq_len, batch, embed_dim)*
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
*(encoder_attention_heads,)*.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size *(decoder_attention_heads,)*.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class XGLMPreTrainedModel(PreTrainedModel):
config_class = XGLMConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, XGLMModel):
module.gradient_checkpointing = value
@add_start_docstrings(
"The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
XGLM_START_DOCSTRING,
)
class XGLMModel(XGLMPreTrainedModel):
"""
Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`XGLMDecoderLayer`]
Args:
config: XGLMConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = XGLMSinusoidalPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
config.pad_token_id,
)
self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~XGLMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"""
The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
XGLM_START_DOCSTRING,
)
class XGLMForCausalLM(XGLMPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"model.embed_positions.weights",
r"lm_head.weight",
]
_keys_to_ignore_on_save = [
r"model.embed_positions.weights",
]
def __init__(self, config):
super().__init__(config)
self.model = XGLMModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
# shift labels and add a pad token to the end
shift_labels = labels.new_zeros(labels.shape)
shift_labels[:, :-1] = labels[:, 1:].clone()
shift_labels[:, -1] = self.config.pad_token_id
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

View File

@ -0,0 +1,313 @@
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ."""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = ""
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/sentencepiece.bpe.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/xglm-564M": 1024,
}
class XGLMTokenizer(PreTrainedTokenizer):
"""
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece).
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the beginning of
sequence. The token used is the `cls_token`.
</Tip>
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.
</Tip>
sep_token (`str`, *optional*, defaults to `"</s>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
cls_token (`str`, *optional*, defaults to `"<s>"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
mask_token (`str`, *optional*, defaults to `"<mask>"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
Additional special tokens used by the tokenizer.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
Attributes:
sp_model (`SentencePieceProcessor`):
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
# Compatibility with the original tokenizer
self.num_madeup_words = 7
madeup_words = [f"<madeupword{i}>" for i in range(self.num_madeup_words)]
kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
kwargs["additional_special_tokens"] += [
word for word in madeup_words if word not in kwargs["additional_special_tokens"]
]
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file
# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1
# Mimic fairseq token-to-id alignment for the first 4 token
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
sp_size = len(self.sp_model)
madeup_words = {f"<madeupword{i}>": sp_size + i + self.fairseq_offset for i in range(self.num_madeup_words)}
self.fairseq_tokens_to_ids.update(madeup_words)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLM-RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.sep_token_id] + token_ids_0
sep = [self.sep_token_id]
return sep + token_ids_0 + sep + sep + token_ids_1
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1))
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
not make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
if token_ids_1 is None:
return len(sep + token_ids_0) * [0]
return len(sep + token_ids_0 + sep + sep + token_ids_1) * [0]
@property
def vocab_size(self):
return len(self.sp_model) + self.fairseq_offset + self.num_madeup_words
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> List[str]:
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
# Need to return unknown token if the SP model returned 0
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -0,0 +1,206 @@
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for XGLM."""
import os
from shutil import copyfile
from typing import List, Optional, Tuple
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
if is_sentencepiece_available():
from .tokenization_xglm import XGLMTokenizer
else:
XGLMTokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/sentencepiece.bpe.model",
},
"tokenizer_file": {
"facebook/xglm-564M": "https://huggingface.co/facebook/xglm-564M/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/xglm-564M": 1024,
}
class XGLMTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" XGLM tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from [`RobertaTokenizer`]
and [`XLNetTokenizer`]. Based on
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the beginning of
sequence. The token used is the `cls_token`.
</Tip>
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.
</Tip>
sep_token (`str`, *optional*, defaults to `"</s>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
cls_token (`str`, *optional*, defaults to `"<s>"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
Additional special tokens used by the tokenizer.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = XGLMTokenizer
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
**kwargs
):
# Compatibility with the original tokenizer
self.num_madeup_words = 7
madeup_words = [f"<madeupword{i}>" for i in range(self.num_madeup_words)]
kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
kwargs["additional_special_tokens"] += [
word for word in madeup_words if word not in kwargs["additional_special_tokens"]
]
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
pad_token=pad_token,
**kwargs,
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLM-RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.sep_token_id] + token_ids_0
sep = [self.sep_token_id]
return sep + token_ids_0 + sep + sep + token_ids_1
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
not make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
if token_ids_1 is None:
return len(sep + token_ids_0) * [0]
return len(sep + token_ids_0 + sep + sep + token_ids_1) * [0]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)

View File

@ -961,3 +961,24 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXGLMForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXGLMModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXGLMPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

View File

@ -3848,6 +3848,30 @@ class WavLMPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class XGLMForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class XGLMModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class XGLMPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -136,6 +136,13 @@ class T5Tokenizer(metaclass=DummyObject):
requires_backends(self, ["sentencepiece"])
class XGLMTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
class XLMProphetNetTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]

View File

@ -297,6 +297,13 @@ class T5TokenizerFast(metaclass=DummyObject):
requires_backends(self, ["tokenizers"])
class XGLMTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
class XLMRobertaTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

View File

@ -0,0 +1,347 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import transformers
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
import numpy as np
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel
if is_torch_available():
import torch
@require_flax
class FlaxXGLMModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
d_model=32,
num_hidden_layers=5,
num_attention_heads=4,
ffn_dim=37,
activation_function="gelu",
activation_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = d_model
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.ffn_dim = ffn_dim
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = 0
self.eos_token_id = 2
self.pad_token_id = 1
def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length], self.vocab_size), 3, self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
config = XGLMConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
num_layers=self.num_hidden_layers,
attention_heads=self.num_attention_heads,
ffn_dim=self.ffn_dim,
activation_function=self.activation_function,
activation_dropout=self.activation_dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
)
return (config, input_ids, input_mask)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, attention_mask = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_sentencepiece
@require_flax
class FlaxXGLMModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxXGLMModel, FlaxXGLMForCausalLM) if is_flax_available() else ()
all_generative_model_classes = (FlaxXGLMForCausalLM,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxXGLMModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
@slow
def test_batch_generation(self):
tokenizer = XGLMTokenizer.from_pretrained("XGLM", padding_side="left")
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.config.num_beams = 1
model.config.do_sample = False
jit_generate = jax.jit(model.generate)
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
expected_string = [
"Hello this is a long string of questions, but I'm not sure if I'm",
"Hey, I'm a newbie to the forum and I'",
]
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/xglm-564M")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)

471
tests/test_modeling_xglm.py Normal file
View File

@ -0,0 +1,471 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import math
import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
import torch
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
class XGLMModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
d_model=32,
num_hidden_layers=5,
num_attention_heads=4,
ffn_dim=37,
activation_function="gelu",
activation_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = d_model
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.ffn_dim = ffn_dim
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = 0
self.eos_token_id = 2
self.pad_token_id = 1
def get_large_model_config(self):
return XGLMConfig.from_pretrained("facebook/xglm-564M")
def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return (
config,
input_ids,
input_mask,
head_mask,
)
def get_config(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
return XGLMConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
num_layers=self.num_hidden_layers,
attention_heads=self.num_attention_heads,
ffn_dim=self.ffn_dim,
activation_function=self.activation_function,
activation_dropout=self.activation_dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
head_mask,
) = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
input_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_xglm_model(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, head_mask=head_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.past_key_values), config.num_hidden_layers)
def create_and_check_xglm_model_past(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMModel(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and token_type_ids
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_xglm_model_attention_mask_past(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMModel(config=config)
model.to(torch_device)
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = self.seq_length // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.zeros((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_xglm_model_past_large_inputs(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMModel(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
output, past = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=1)
# append to next input_ids
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[
"last_hidden_state"
]
self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, *args, gradient_checkpointing=False
):
model = XGLMForCausalLM(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
result.loss.backward()
def create_and_check_xglm_weight_initialization(self, config, *args):
model = XGLMModel(config)
model_std = model.config.initializer_range / math.sqrt(2 * model.config.num_hidden_layers)
for key in model.state_dict().keys():
if "c_proj" in key and "weight" in key:
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
head_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"head_mask": head_mask,
}
return config, inputs_dict
@require_torch
class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
test_missing_keys = False
test_pruning = False
def setUp(self):
self.model_tester = XGLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=XGLMConfig, n_embd=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_xglm_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_model(*config_and_inputs)
def test_xglm_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_model_past(*config_and_inputs)
def test_xglm_model_att_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_model_attention_mask_past(*config_and_inputs)
def test_xglm_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_model_past_large_inputs(*config_and_inputs)
def test_xglm_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_xglm_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
def test_xglm_weight_initialization(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
@slow
def test_batch_generation(self):
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
tokenizer.padding_side = "left"
# use different length sentences to test batching
sentences = [
"Hello, my dog is a little",
"Today, I",
]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
outputs = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
)
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
output_non_padded = model.generate(input_ids=inputs_non_padded)
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
expected_output_sentence = [
"Hello, my dog is a little bit of a shy one, but he is very friendly",
"Today, I am going to share with you a few of my favorite things",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
@slow
def test_model_from_pretrained(self):
for model_name in XGLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = XGLMModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class XGLMModelLanguageGenerationTest(unittest.TestCase):
def _test_lm_generate_xglm_helper(
self,
gradient_checkpointing=False,
verify_outputs=True,
):
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
if gradient_checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[2, 268, 9865]], dtype=torch.long, device=torch_device) # The dog
# </s> The dog is a very friendly dog. He is very affectionate and loves to play with other
# fmt: off
expected_output_ids = [2, 268, 9865, 67, 11, 1988, 57252, 9865, 5, 984, 67, 1988, 213838, 1658, 53, 70446, 33, 6657, 278, 1581]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False, num_beams=1)
if verify_outputs:
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_lm_generate_xglm(self):
self._test_lm_generate_xglm_helper()
@slow
def test_lm_generate_xglm_with_gradient_checkpointing(self):
self._test_lm_generate_xglm_helper(gradient_checkpointing=True)
@slow
def test_xglm_sample(self):
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
input_ids = tokenized.input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
@slow
def test_xglm_sample_max_time(self):
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
input_ids = tokenized.input_ids.to(torch_device)
MAX_TIME = 0.15
start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.25 * MAX_TIME))

View File

@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import shutil
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, XGLMTokenizer, XGLMTokenizerFast
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
@require_sentencepiece
@require_tokenizers
class XGLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XGLMTokenizer
rust_tokenizer_class = XGLMTokenizerFast
test_rust_tokenizer = True
test_sentencepiece = True
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = XGLMTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "<pad>"
token_id = 1
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[0], "<s>")
self.assertEqual(vocab_keys[1], "<pad>")
self.assertEqual(len(vocab_keys), 1_008)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 1_008)
def test_full_tokenizer(self):
tokenizer = XGLMTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[
value + tokenizer.fairseq_offset
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
@cached_property
def big_tokenizer(self):
return XGLMTokenizer.from_pretrained("facebook/xglm-564M")
def test_picklable_without_disk(self):
with tempfile.NamedTemporaryFile() as f:
shutil.copyfile(SAMPLE_VOCAB, f.name)
tokenizer = XGLMTokenizer(f.name, keep_accents=True)
pickled_tokenizer = pickle.dumps(tokenizer)
pickle.loads(pickled_tokenizer)
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
sequence = "I was born in 92000, and this is falsé."
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
@slow
def test_tokenization_base_easy_symbols(self):
symbols = "Hello World!"
original_tokenizer_encodings = [2, 31227, 4447, 35]
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenization_base_hard_symbols(self):
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth'
# fmt: off
original_tokenizer_encodings = [2, 1018, 67, 11, 1988, 2617, 5631, 278, 11, 3407, 48, 71630, 28085, 4, 3234, 157, 13, 6, 5, 6, 4, 3526, 768, 15, 659, 57, 298, 3983, 864, 129, 21, 6, 5, 13675, 377, 652, 7580, 10341, 155, 2817, 422, 1666, 7, 1674, 53, 113, 202277, 17892, 33, 60, 87, 4, 3234, 157, 61, 2667, 52376, 19, 88, 23, 735]
# fmt: on
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenizer_integration(self):
# fmt: off
expected_encoding = {
'input_ids': [[2, 108825, 1163, 15, 88010, 473, 15898, 157, 13672, 1857, 312, 8, 238021, 1163, 53, 13672, 1857, 312, 8, 53283, 182396, 8, 18566, 16, 36733, 4101, 8, 230, 244017, 122553, 7, 15, 132597, 4, 293, 12511, 7610, 4, 3414, 132597, 9, 4, 32361, 362, 4, 734, 28512, 32569, 18, 4, 32361, 26096, 14982, 73, 18715, 21433, 235261, 15, 492, 12427, 16, 53, 18715, 21433, 65454, 15, 23659, 563, 16, 278, 597, 2843, 595, 7931, 182396, 64186, 22, 886, 595, 132981, 53, 25540, 3449, 43982, 39901, 5951, 878, 330, 4, 27694, 80269, 312, 53, 6517, 11780, 611, 20408, 5], [2, 6, 132597, 67, 42897, 33, 592, 8, 163729, 25540, 361, 136997, 109514, 173230, 7, 501, 60, 102913, 196, 5631, 235, 63243, 473, 6, 231757, 74, 5277, 7905, 53, 3095, 37317, 22, 454, 183874, 5], [2, 268, 31298, 46530, 6, 132935, 43831, 7, 597, 32, 24, 3688, 9865, 5]],
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
} # noqa: E501
# fmt: on
self.tokenizer_integration_test_util(
expected_encoding=expected_encoding,
model_name="facebook/xglm-564M",
padding=False,
)

View File

@ -112,6 +112,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"ViltForImagesAndTextClassification",
"ViltForImageAndTextRetrieval",
"ViltForMaskedLM",
"XGLMEncoder",
"XGLMDecoder",
"XGLMDecoderWrapper",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",