diff --git a/README.md b/README.md index 6706f4da6cf..f6e503896b6 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. +1. **[M2M100](https://huggingface.co/transformers/model_doc/m2m_100.html)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/docs/source/index.rst b/docs/source/index.rst index 53788e7e960..3dbc6b6dd65 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -161,57 +161,61 @@ and conversion utilities for the following models: 26. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -27. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +27. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual + Machine Translation <https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi + Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman + Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. +28. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -28.
:doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +29. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -29. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +30. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -30. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +31. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -31. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +32. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -32. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +33. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -33. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +34. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -34. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +35. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -35. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +36. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -36. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +37. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -37. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +38. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -38. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +39. 
:doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -39. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +40. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -40. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +41. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -41. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +42. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -42. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +43. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -43. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +44. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -44. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +45. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. @@ -276,6 +280,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| M2M100 | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | MPNet | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Marian | ✅ | ❌ | ✅ | ✅ | ❌ | @@ -416,6 +422,7 @@ TensorFlow and/or Flax. model_doc/longformer model_doc/lxmert model_doc/marian + model_doc/m2m_100 model_doc/mbart model_doc/mobilebert model_doc/mpnet diff --git a/docs/source/model_doc/m2m_100.rst b/docs/source/model_doc/m2m_100.rst new file mode 100644 index 00000000000..b5c8d46bc91 --- /dev/null +++ b/docs/source/model_doc/m2m_100.rst @@ -0,0 +1,125 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +M2M100 +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The M2M100 model was proposed in `Beyond English-Centric Multilingual Machine Translation +<https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, +Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy +Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. + +The abstract from the paper is the following: + +*Existing work in translation demonstrated the potential of massively multilingual machine translation by training a +single model able to translate between any pair of languages. However, much of this work is English-Centric by training +only on data which was translated from or to English. While this is supported by large sources of training data, it +does not reflect translation needs worldwide. In this work, we create a true Many-to-Many multilingual translation +model that can translate directly between any pair of 100 languages. We build and open source a training dataset that +covers thousands of language directions with supervised data, created through large-scale mining. Then, we explore how +to effectively increase model capacity through a combination of dense scaling and language-specific sparse parameters +to create high quality models. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly +translating between non-English directions while performing competitively to the best single systems of WMT. We +open-source our scripts so that others may reproduce the data, evaluation, and final M2M-100 model.* + + +Training and Generation +_______________________________________________________________________________________________________________________ + +M2M100 is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation tasks. As the model is +multilingual, it expects the sequences in a certain format: a special language id token is used as a prefix in both the +source and target text. The text format is :obj:`[lang_code] X [eos]`, where :obj:`lang_code` is the source language +id for the source text and the target language id for the target text, with :obj:`X` being the source or target text. + +- Supervised Training + +.. code-block:: + + from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer + + model = M2M100ForConditionalGeneration.from_pretrained('facebook/m2m100_418M') + tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M', src_lang="en", tgt_lang="fr") + + src_text = "Life is like a box of chocolates." + tgt_text = "La vie est comme une boîte de chocolat."
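+    # the plain tokenizer call below builds the source sequence as `[src_lang_code] X [eos]`; inside `as_target_tokenizer()` the target text is encoded the same way with the target language id, giving the `labels`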
+ + model_inputs = tokenizer(src_text, return_tensors="pt") + with tokenizer.as_target_tokenizer(): + labels = tokenizer(tgt_text, return_tensors="pt").input_ids + + loss = model(**model_inputs, labels=labels).loss # forward pass + + +- Generation + + M2M100 uses the :obj:`eos_token_id` as the :obj:`decoder_start_token_id` for generation, with the target language id + forced as the first generated token. To force the target language id as the first generated token, pass the + :obj:`forced_bos_token_id` parameter to the :obj:`generate` method. The following example shows how to translate from + Hindi to French and from Chinese to English using the :obj:`facebook/m2m100_418M` checkpoint. + +.. code-block:: + + >>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer + + >>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।" + >>> chinese_text = "生活就像一盒巧克力。" + + >>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M") + >>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M") + + >>> # translate Hindi to French + >>> tokenizer.src_lang = "hi" + >>> encoded_hi = tokenizer(hi_text, return_tensors="pt") + >>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr")) + >>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + "La vie est comme une boîte de chocolat." + + >>> # translate Chinese to English + >>> tokenizer.src_lang = "zh" + >>> encoded_zh = tokenizer(chinese_text, return_tensors="pt") + >>> generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en")) + >>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + "Life is like a box of chocolate." + + +M2M100Config +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.M2M100Config + :members: + + +M2M100Tokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.M2M100Tokenizer + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +M2M100Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.M2M100Model + :members: forward + + +M2M100ForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +..
autoclass:: transformers.M2M100ForConditionalGeneration + :members: forward + + diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index c1a71ad35d8..4a29ebf4eea 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -365,6 +365,12 @@ For the full list, refer to `https://huggingface.co/models `_) | +--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0856c68edcf..2b6a037892c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -134,6 +134,7 @@ _import_structure = { "Wav2Vec2FeatureExtractor", "Wav2Vec2Processor", ], + "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100Tokenizer"], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ @@ -385,6 +386,14 @@ if is_torch_available(): "Wav2Vec2PreTrainedModel", ] ) + _import_structure["models.m2m_100"].extend( + [ + "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", + "M2M100ForConditionalGeneration", + "M2M100Model", + ] + ) + _import_structure["models.convbert"].extend( [ "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1348,6 +1357,7 @@ if TYPE_CHECKING: from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer + from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100Tokenizer from .models.marian import MarianConfig from .models.mbart import MBartConfig from .models.mmbt import MMBTConfig @@ -1768,6 +1778,7 @@ if TYPE_CHECKING: LxmertVisualFeatureEncoder, LxmertXLayer, ) + from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.mbart import ( MBartForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f7f9a9e58de..d4957cb76cb 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -45,6 +45,7 @@ from . 
import ( led, longformer, lxmert, + m2m_100, marian, mbart, mmbt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 338f2737575..4a9be13e52b 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -45,6 +45,7 @@ from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig +from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config from ..marian.configuration_marian import MarianConfig from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig from ..mobilebert.configuration_mobilebert import MobileBertConfig @@ -76,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( for pretrained_map in [ # Add archive maps here WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, + M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LED_PRETRAINED_CONFIG_ARCHIVE_MAP, BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -121,6 +123,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here ("wav2vec2", Wav2Vec2Config), + ("m2m_100", M2M100Config), ("convbert", ConvBertConfig), ("led", LEDConfig), ("blenderbot-small", BlenderbotSmallConfig), @@ -172,6 +175,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here ("wav2vec2", "Wav2Vec2"), + ("m2m_100", "M2M100"), ("convbert", "ConvBERT"), ("led", "LED"), ("blenderbot-small", "BlenderbotSmall"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 39e8b70b3ce..99a72320e3a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -158,6 +158,7 @@ from ..longformer.modeling_longformer import ( LongformerModel, ) from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel +from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel from ..mbart.modeling_mbart import ( MBartForCausalLM, @@ -283,6 +284,7 @@ from .configuration_auto import ( LEDConfig, LongformerConfig, LxmertConfig, + M2M100Config, MarianConfig, MBartConfig, MobileBertConfig, @@ -314,6 +316,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping (Wav2Vec2Config, Wav2Vec2Model), + (M2M100Config, M2M100Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), (BlenderbotSmallConfig, BlenderbotSmallModel), @@ -397,6 +400,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping (Wav2Vec2Config, Wav2Vec2ForMaskedLM), + (M2M100Config, M2M100ForConditionalGeneration), (ConvBertConfig, ConvBertForMaskedLM), (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), @@ -495,6 +499,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (M2M100Config, M2M100ForConditionalGeneration), (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), (MT5Config, MT5ForConditionalGeneration), diff --git 
a/src/transformers/models/m2m_100/__init__.py b/src/transformers/models/m2m_100/__init__.py new file mode 100644 index 00000000000..5b521ab9370 --- /dev/null +++ b/src/transformers/models/m2m_100/__init__.py @@ -0,0 +1,67 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available + + +_import_structure = { + "configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], + "tokenization_m2m_100": ["M2M100Tokenizer"], +} + + +if is_torch_available(): + _import_structure["modeling_m2m_100"] = [ + "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", + "M2M100ForConditionalGeneration", + "M2M100Model", + "M2M100PreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config + from .tokenization_m2m_100 import M2M100Tokenizer + + if is_torch_available(): + from .modeling_m2m_100 import ( + M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, + M2M100ForConditionalGeneration, + M2M100Model, + M2M100PreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/m2m_100/configuration_m2m_100.py b/src/transformers/models/m2m_100/configuration_m2m_100.py new file mode 100644 index 00000000000..725be8f7965 --- /dev/null +++ b/src/transformers/models/m2m_100/configuration_m2m_100.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" M2M100 model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/config.json", + # See all M2M100 models at https://huggingface.co/models?filter=m2m_100 +} + + +class M2M100Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.M2M100Model`. It is used to + instantiate an M2M100 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the M2M100 `m2m100_418M + <https://huggingface.co/facebook/m2m100_418M>`__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 128112): + Vocabulary size of the M2M100 model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.M2M100Model`. + d_model (:obj:`int`, `optional`, defaults to 1024): + Dimensionality of the layers and the pooler layer. + encoder_layers (:obj:`int`, `optional`, defaults to 12): + Number of encoder layers. + decoder_layers (:obj:`int`, `optional`, defaults to 12): + Number of decoder layers. + encoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + max_position_embeddings (:obj:`int`, `optional`, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.05): + The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details.
+ decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import M2M100Model, M2M100Config + + >>> # Initializing a M2M100 facebook/m2m100_418M style configuration + >>> configuration = M2M100Config() + + >>> # Initializing a model from the facebook/m2m100_418M style configuration + >>> model = M2M100Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "m2m_100" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128112, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=1024, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + scale_embedding=True, + gradient_checkpointing=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model diff --git a/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..74580bc181f --- /dev/null +++ b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py @@ -0,0 +1,85 @@ +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from torch import nn + +from transformers import M2M100Config, M2M100ForConditionalGeneration + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "decoder.output_projection.weight", + "_float_tensor", + "encoder.embed_positions._float_tensor", + "decoder.embed_positions._float_tensor", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path): + m2m_100 = torch.load(checkpoint_path, map_location="cpu") + args = m2m_100["args"] + state_dict = m2m_100["model"] + remove_ignore_keys_(state_dict) + vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] + + config = M2M100Config( + vocab_size=vocab_size, + max_position_embeddings=1024, + encoder_layers=args.encoder_layers, + decoder_layers=args.decoder_layers, + encoder_attention_heads=args.encoder_attention_heads, + decoder_attention_heads=args.decoder_attention_heads, + encoder_ffn_dim=args.encoder_ffn_embed_dim, + decoder_ffn_dim=args.decoder_ffn_embed_dim, + d_model=args.encoder_embed_dim, + encoder_layerdrop=args.encoder_layerdrop, + decoder_layerdrop=args.decoder_layerdrop, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_function="relu", + ) + + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + model = M2M100ForConditionalGeneration(config) + model.model.load_state_dict(state_dict) + model.lm_head = make_linear_from_emb(model.model.shared) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.") + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_path) + model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py new file mode 100755 index 00000000000..bb9f56a443e --- /dev/null +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -0,0 +1,1299 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch M2M100 model. """ + + +import math +import random +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_m2m_100 import M2M100Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "M2M100Config" +_TOKENIZER_FOR_DOC = "M2M100Tokenizer" + + +M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/m2m100_418M", + # See all M2M100 models at https://huggingface.co/models?filter=m2m_100 +] + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
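+    # e.g. with padding_idx=1 and input_ids [[5, 7, 1, 1]]: the mask is [[1, 1, 0, 0]], the cumulative sum times the mask is [[1, 2, 0, 0]], and adding padding_idx gives position ids [[2, 3, 1, 1]], so non-padding tokens count up from padding_idx + 1 while padding positions keep padding_idx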
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +class M2M100SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = self.get_embedding(num_positions + self.offset, embedding_dim, padding_idx) + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + @torch.no_grad() + def forward( + self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0 + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) + + self.weights = self.weights.to(self._float_tensor) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100 +class M2M100Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100 +class M2M100EncoderLayer(nn.Module): + def __init__(self, config: M2M100Config): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = M2M100Attention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: 
bool = False, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100 +class M2M100DecoderLayer(nn.Module): + def __init__(self, config: M2M100Config): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = M2M100Attention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = M2M100Attention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large 
negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class M2M100PreTrainedModel(PreTrainedModel): + config_class = M2M100Config + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + 
elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +M2M_100_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.M2M100Config`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +M2M_100_GENERATION_EXAMPLE = r""" + Translation example:: + + >>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration + + >>> model = M2M100ForConditionalGeneration.from_pretrained('facebook/m2m100_418M') + >>> tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M') + + >>> text_to_translate = "Life is like a box of chocolates" + >>> model_inputs = tokenizer(text_to_translate, return_tensors='pt') + + >>> # translate to French + >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr")) + >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)) +""" + +M2M_100_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.M2M100Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the :obj:`input_ids` to the right, following the paper. + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read :func:`modeling_m2m_100._prepare_decoder_inputs` + and modify to your needs. See diagram 1 in `the paper `__ for more + information on the default strategy.
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class M2M100Encoder(M2M100PreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`M2M100EncoderLayer`. 
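+
+    The encoder scales the token embeddings by ``sqrt(d_model)`` when ``config.scale_embedding`` is set, adds
+    sinusoidal position embeddings, and layer-normalizes the output of the last layer.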
+ + Args: + config: M2M100Config + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = M2M100SinusoidalPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.init_weights() + + # Copied from transformers.models.mbart.modeling_mbart.MBartEncoder.forward with MBart->M2M100 + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.M2M100Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_ids, inputs_embeds) + + hidden_states = inputs_embeds + embed_pos + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class M2M100Decoder(M2M100PreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a :class:`M2M100DecoderLayer` + + Args: + config: M2M100Config + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = M2M100SinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.init_weights() + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->M2M100 + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.M2M100Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. 
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) + + hidden_states = inputs_embeds + positions + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} 
layers, but it is for {head_mask.size()[0]}." + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare M2M100 Model outputting raw hidden-states without any specific head on top.", + M2M_100_START_DOCSTRING, +) +class M2M100Model(M2M100PreTrainedModel): + def __init__(self, config: M2M100Config): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = M2M100Encoder(config, self.shared) + self.decoder = M2M100Decoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="facebook/m2m100_418M", + output_type=Seq2SeqModelOutput, 
+ config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The M2M100 Model with a language modeling head. 
Can be used for summarization.", M2M_100_START_DOCSTRING +) +class M2M100ForConditionalGeneration(M2M100PreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"lm_head\.weight", + ] + + def __init__(self, config: M2M100Config): + super().__init__(config) + self.model = M2M100Model(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + self._resize_final_logits_bias(new_num_tokens) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(M2M_100_GENERATION_EXAMPLE) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 
+
+        Returns:
+
+        Example::
+
+            >>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
+
+            >>> model = M2M100ForConditionalGeneration.from_pretrained('facebook/m2m100_418M')
+            >>> tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M')
+
+            >>> text_to_translate = "Life is like a box of chocolates"
+            >>> model_inputs = tokenizer(text_to_translate, return_tensors='pt')
+
+            >>> # translate to French
+            >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
+            >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py
new file mode 100644
index 00000000000..cd449fa84a2
--- /dev/null
+++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -0,0 +1,334 @@
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for M2M100."""
+import json
+from contextlib import contextmanager
+from pathlib import Path
+from shutil import copyfile
+from typing import Dict, List, Optional, Tuple, Union
+
+import sentencepiece
+
+from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "spm_file": "sentencepiece.bpe.model",
+    "tokenizer_config_file": "tokenizer_config.json",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
+    },
+    "spm_file": {
+        "facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
+    },
+    "tokenizer_config_file": {
+        "facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/tokenizer_config.json",
+    },
+}
+
+ALL_M2M100_MODELS = ["facebook/m2m100_418M", "facebook/m2m100_1.2B"]
+SPM_URL = "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentence.bpe.model"
+
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
+# fmt: on
+
+
+class M2M100Tokenizer(PreTrainedTokenizer):
+    """
+    Construct an M2M100 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.
+
+    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
+    Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (:obj:`str`):
+            Path to the vocabulary file.
+        spm_file (:obj:`str`):
+            Path to `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension)
+            that contains the vocabulary.
+        src_lang (:obj:`str`, `optional`):
+            A string representing the source language.
+        tgt_lang (:obj:`str`, `optional`):
+            A string representing the target language.
+        eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
+            The end of sequence token.
+        sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+
+    Examples::
+
+        >>> from transformers import M2M100Tokenizer
+        >>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
+        >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+        >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+        >>> model_inputs = tokenizer(src_text, return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(tgt_text, return_tensors="pt").input_ids
+        >>> # model(**model_inputs, labels=labels) should work
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    max_model_input_sizes = {m: 1024 for m in ALL_M2M100_MODELS}
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        spm_file,
+        src_lang=None,
+        tgt_lang=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        pad_token="<pad>",
+        unk_token="<unk>",
+        **kwargs,
+    ):
+        super().__init__(
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+        self.encoder = load_json(vocab_file)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.spm_file = spm_file
+        self.sp_model = load_spm(spm_file)
+
+        self.encoder_size = len(self.encoder)
+
+        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in FAIRSEQ_LANGUAGE_CODES}
+
+        self.lang_token_to_id = {
+            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+        }
+        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)}
+        self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
+        self._additional_special_tokens = list(self.lang_token_to_id.keys())
+
+        self._src_lang = src_lang if src_lang is not None else "en"
+        self.tgt_lang = tgt_lang
+        self.cur_lang_id = self.get_lang_id(self._src_lang)
+        self.set_src_lang_special_tokens(self._src_lang)
+
+        self.num_madeup_words = 8
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.encoder) + len(self.lang_token_to_id) + self.num_madeup_words
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.EncodeAsPieces(text)
+
+    def _convert_token_to_id(self, token):
+        if token in self.lang_token_to_id:
+            return self.lang_token_to_id[token]
+        return self.encoder.get(token, self.encoder[self.unk_token])
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) in a token (str) using the decoder."""
+        if index in self.id_to_lang_token:
+            return self.id_to_lang_token[index]
+        return self.decoder.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer ``prepare_for_model`` method.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+        prefix_ones = [1] * len(self.prefix_tokens)
+        suffix_ones = [1] * len(self.suffix_tokens)
+        if token_ids_1 is None:
+            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. An M2M100 sequence has the following format, where ``X`` represents the sequence:
+
+        - ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
+        - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
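+
+        Example (a sketch assuming a tokenizer loaded from ``facebook/m2m100_418M`` with ``src_lang="en"``, where
+        ``__en__`` has id 128022 and ``</s>`` has id 2; the other ids are illustrative)::
+
+            >>> tokenizer.build_inputs_with_special_tokens([593, 1949])
+            [128022, 593, 1949, 2]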
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def get_vocab(self) -> Dict: + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self) -> Dict: + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d: Dict) -> None: + self.__dict__ = d + self.sp_model = load_spm(self.spm_file) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + save_dir = Path(save_directory) + assert save_dir.is_dir(), f"{save_directory} should be a directory" + vocab_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"] + ) + spm_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"] + ) + + save_json(self.encoder, vocab_save_path) + + if not spm_save_path.exists(): + copyfile(self.spm_file, spm_save_path) + + return (str(vocab_save_path), str(spm_save_path)) + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "ro", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self.src_lang) + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) + + def set_src_lang_special_tokens(self, src_lang: str) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code].""" + lang_token = self.get_lang_token(src_lang) + self.cur_lang_id = self.lang_token_to_id[lang_token] + self.prefix_tokens = [self.cur_lang_id] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. 
No prefix and suffix=[eos, tgt_lang_code].""" + lang_token = self.get_lang_token(tgt_lang) + self.cur_lang_id = self.lang_token_to_id[lang_token] + self.prefix_tokens = [self.cur_lang_id] + self.suffix_tokens = [self.eos_token_id] + + def get_lang_token(self, lang: str) -> str: + return self.lang_code_to_token[lang] + + def get_lang_id(self, lang: str) -> int: + lang_token = self.get_lang_token(lang) + return self.lang_token_to_id[lang_token] + + +def load_spm(path: str) -> sentencepiece.SentencePieceProcessor: + spm = sentencepiece.SentencePieceProcessor() + spm.Load(str(path)) + return spm + + +def load_json(path: str) -> Union[Dict, List]: + with open(path, "r") as f: + return json.load(f) + + +def save_json(data, path: str) -> None: + with open(path, "w") as f: + json.dump(data, f, indent=2) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e0c80e15786..fb782d65059 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1593,6 +1593,27 @@ class LxmertXLayer: requires_pytorch(self) +M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class M2M100ForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class M2M100Model: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class MarianForCausalLM: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_m2m_100.py b/tests/test_modeling_m2m_100.py new file mode 100644 index 00000000000..0e02ebdc189 --- /dev/null +++ b/tests/test_modeling_m2m_100.py @@ -0,0 +1,353 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch M2M100 model. 
""" + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer + from transformers.models.m2m_100.modeling_m2m_100 import M2M100Decoder, M2M100Encoder + + +def prepare_m2m_100_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +@require_torch +class M2M100ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="relu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( + 3, + ) + input_ids[:, -1] = self.eos_token_id # Eos Token + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = M2M100Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + ) + inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_decoder_model_past_large_inputs(self, config, 
inputs_dict): + model = M2M100Model(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = M2M100Model(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = M2M100Encoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ + 0 + ] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = M2M100Decoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +@require_torch +class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + M2M100Model, + M2M100ForConditionalGeneration, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = M2M100ModelTester(self) + self.config_tester = ConfigTester(self, config_class=M2M100Config) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as 
tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (M2M100Model, M2M100ForConditionalGeneration): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = M2M100ForConditionalGeneration(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(input_ids, attention_mask=attention_mask) + model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +TOLERANCE = 1e-4 + + +@require_torch +@require_sentencepiece +@require_tokenizers +@slow +class M2M100ModelIntegrationTests(unittest.TestCase): + @cached_property + def default_tokenizer(self): + return M2M100Tokenizer.from_pretrained("facebook/m2m100_418M") + + def test_inference_no_head(self): + model = M2M100Model.from_pretrained("facebook/m2m100_418M").to(torch_device) + input_ids = _long_tensor([[128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38, 2]]) + decoder_input_ids = _long_tensor([[2, 128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38]]) + inputs_dict = prepare_m2m_100_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, 1024)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_inference_head(self): + model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(torch_device) + + # change to intended input + input_ids = _long_tensor([[128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38, 2]]) + decoder_input_ids = _long_tensor([[2, 128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38]]) + inputs_dict = 
prepare_m2m_100_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, model.config.vocab_size)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_seq_to_seq_generation(self): + model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(torch_device) + tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") + + src_fr = [ + "L'affaire NSA souligne l'absence totale de débat sur le renseignement", + "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.", + "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de l'ampleur de la surveillance américaine sur l'ensemble des communications en France.", + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tokenizer(src_fr, padding=True, return_tensors="pt") + + hypotheses_batch = model.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=5, + forced_bos_token_id=tokenizer.get_lang_id("en"), + ) + + expected_en = [ + "The NSA case highlights the total absence of intelligence debate", + "I think there are two levels of response from the French government.", + "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S. Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all communications in France.", + ] + + generated = tokenizer.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == expected_en diff --git a/tests/test_tokenization_m2m_100.py b/tests/test_tokenization_m2m_100.py new file mode 100644 index 00000000000..649d471deb1 --- /dev/null +++ b/tests/test_tokenization_m2m_100.py @@ -0,0 +1,193 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
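+
+# Note: EN_CODE and FR_CODE below are the vocabulary ids of the "__en__" and "__fr__" language tokens of the
+# facebook/m2m100_418M checkpoint (see M2M100Tokenizer.get_lang_id); the tokenizer prepends them to encoded text.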
+ +import os +import tempfile +import unittest +from pathlib import Path +from shutil import copyfile + +from transformers import M2M100Tokenizer, is_torch_available +from transformers.file_utils import is_sentencepiece_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch + + +if is_sentencepiece_available(): + from transformers.models.m2m_100.tokenization_m2m_100 import save_json, VOCAB_FILES_NAMES + +from .test_tokenization_common import TokenizerTesterMixin + + +if is_sentencepiece_available(): + SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + + +if is_torch_available(): + from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right + +EN_CODE = 128022 +FR_CODE = 128028 + + +@require_sentencepiece +class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = M2M100Tokenizer + test_rust_tokenizer = False + test_seq2seq = False + + def setUp(self): + super().setUp() + + vocab = ["", "", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", ""] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + save_dir = Path(self.tmpdirname) + save_json(vocab_tokens, save_dir / VOCAB_FILES_NAMES["vocab_file"]) + if not (save_dir / VOCAB_FILES_NAMES["spm_file"]).exists(): + copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["spm_file"]) + + tokenizer = M2M100Tokenizer.from_pretrained(self.tmpdirname) + tokenizer.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return M2M100Tokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self, tokenizer): + return ( + "This is a test", + "This is a test", + ) + + @unittest.skip("Skip this test while all models are still to be uploaded.") + def test_pretrained_model_lists(self): + pass + + def test_full_tokenizer(self): + tokenizer = self.get_tokenizer() + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [2, 3, 4, 5, 6], + ) + + back_tokens = tokenizer.convert_ids_to_tokens([2, 3, 4, 5, 6]) + self.assertListEqual(back_tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + text = tokenizer.convert_tokens_to_string(tokens) + self.assertEqual(text, "This is a test") + + +@require_torch +@require_sentencepiece +@require_tokenizers +class M2M100TokenizerIntegrationTest(unittest.TestCase): + checkpoint_name = "facebook/m2m100_418M" + src_text = [ + "In my opinion, there are two levels of response from the French government.", + "NSA Affair Emphasizes Complete Lack of Debate on Intelligence", + ] + tgt_text = [ + "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.", + "L'affaire NSA souligne l'absence totale de débat sur le renseignement", + ] + + # fmt: off + expected_src_tokens = [EN_CODE, 593, 1949, 115781, 4, 71586, 4234, 60633, 126233, 432, 123808, 15592, 1197, 117132, 120618, 5, 2] + # fmt: on + + @classmethod + def setUpClass(cls): + cls.tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained( + cls.checkpoint_name, src_lang="en", tgt_lang="fr" + ) + cls.pad_token_id = 1 + return cls + + def check_language_codes(self): + self.assertEqual(self.tokenizer.get_lang_id("ar"), 128006) + self.assertEqual(self.tokenizer.get_lang_id("en"), 128022) + self.assertEqual(self.tokenizer.get_lang_id("ro"), 128076) + self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063) + + def 
+    def test_tokenizer_batch_encode_plus(self):
+        self.tokenizer.src_lang = "en"
+        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+        self.assertListEqual(self.expected_src_tokens, ids)
+
+    def test_tokenizer_decode_ignores_language_codes(self):
+        self.assertIn(FR_CODE, self.tokenizer.all_special_ids)
+        # fmt: off
+        generated_ids = [FR_CODE, 5364, 82, 8642, 4, 294, 47, 8, 14028, 136, 3286, 9706, 6, 90797, 6, 144012, 162, 88128, 30061, 5, 2]
+        # fmt: on
+        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+        expected_french = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
+        self.assertEqual(result, expected_french)
+        self.assertNotIn(self.tokenizer.eos_token, result)
+
+    def test_special_tokens_unaffected_by_save_load(self):
+        tmpdirname = tempfile.mkdtemp()
+        original_special_tokens = self.tokenizer.lang_token_to_id
+        self.tokenizer.save_pretrained(tmpdirname)
+        new_tok = M2M100Tokenizer.from_pretrained(tmpdirname)
+        self.assertDictEqual(new_tok.lang_token_to_id, original_special_tokens)
+
+    @require_torch
+    def test_batch_fairseq_parity(self):
+        self.tokenizer.src_lang = "en"
+        self.tokenizer.tgt_lang = "fr"
+
+        batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
+        with self.tokenizer.as_target_tokenizer():
+            batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
+
+        batch["decoder_input_ids"] = shift_tokens_right(
+            batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
+        )
+
+        for k in batch:
+            batch[k] = batch[k].tolist()
+        # Reference fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
+        # input_ids and labels start with the language code and end with EOS (id 2);
+        # decoder_input_ids start with [eos, tgt_lang_code] after shift_tokens_right.
+        assert batch.input_ids[1][0] == EN_CODE
+        assert batch.input_ids[1][-1] == 2
+        assert batch.labels[1][0] == FR_CODE
+        assert batch.labels[1][-1] == 2
+        assert batch.decoder_input_ids[1][:2] == [2, FR_CODE]
+
+    @require_torch
+    def test_src_lang_setter(self):
+        self.tokenizer.src_lang = "mr"
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
+        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+
+        self.tokenizer.src_lang = "zh"
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
+        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+
+    @require_torch
+    def test_as_target_tokenizer(self):
+        self.tokenizer.tgt_lang = "mr"
+        with self.tokenizer.as_target_tokenizer():
+            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
+            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
+
+        self.tokenizer.tgt_lang = "zh"
+        with self.tokenizer.as_target_tokenizer():
+            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
+            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
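The tokenizer tests above pin down M2M100's language-code conventions: source text is encoded as the source language code, then the tokens, then EOS, and target text gets the target language code while `as_target_tokenizer()` is active. A minimal sketch of that behaviour, again assuming the `facebook/m2m100_418M` checkpoint (illustrative only, not part of this diff):

```python
from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="fr")

# Source side: the English language code comes first, EOS last.
enc = tokenizer("The NSA case highlights the total absence of intelligence debate")
assert enc.input_ids[0] == tokenizer.get_lang_id("en")
assert enc.input_ids[-1] == tokenizer.eos_token_id

# Target side: the French language code is used while the target tokenizer is active.
with tokenizer.as_target_tokenizer():
    labels = tokenizer("L'affaire NSA souligne l'absence totale de débat sur le renseignement").input_ids
assert labels[0] == tokenizer.get_lang_id("fr")
assert labels[-1] == tokenizer.eos_token_id
```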
diff --git a/utils/check_repo.py b/utils/check_repo.py
index c8881baa651..0db06c9e792 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -30,6 +30,8 @@ PATH_TO_DOC = "docs/source"
 # Being in this list is an exception and should **not** be the rule.
 IGNORE_NON_TESTED = [
     # models to ignore for not tested
+    "M2M100Encoder",  # Building part of bigger (tested) model.
+    "M2M100Decoder",  # Building part of bigger (tested) model.
     "LEDEncoder",  # Building part of bigger (tested) model.
     "LEDDecoder",  # Building part of bigger (tested) model.
     "BartDecoderWrapper",  # Building part of bigger (tested) model.
@@ -75,6 +77,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
 # should **not** be the rule.
 IGNORE_NON_AUTO_CONFIGURED = [
     # models to ignore for model xxx mapping
+    "M2M100Encoder",
+    "M2M100Decoder",
     "LEDEncoder",
     "LEDDecoder",
     "BartDecoder",