mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[ported model] FSMT (FairSeq MachineTranslation) (#6940)
* ready for PR
* cleanup
* correct FSMT_PRETRAINED_MODEL_ARCHIVE_LIST
* fix
* perfectionism
* revert change from another PR
* odd, already committed this one
* non-interactive upload workaround
* backup the failed experiment
* store langs in config
* workaround for localizing model path
* doc clean up as in https://github.com/huggingface/transformers/pull/6956
* style
* back out debug mode
* document: run_eval.py --num_beams 10
* remove unneeded constant
* typo
* re-use bart's Attention
* re-use EncoderLayer, DecoderLayer from bart
* refactor
* send to cuda and fp16
* cleanup
* revert (moved to another PR)
* better error message
* document run_eval --num_beams
* solve the problem of tokenizer finding the right files when model is local
* polish, remove hardcoded config
* add a note that the file is autogenerated to avoid losing changes
* prep for org change, remove unneeded code
* switch to model4.pt, update scores
* s/python/bash/
* missing init (but doesn't impact the finetuned model)
* cleanup
* major refactor (reuse-bart)
* new model, new expected weights
* cleanup
* cleanup
* full link
* fix model type
* merge porting notes
* style
* cleanup
* have to create a DecoderConfig object to handle vocab_size properly
* doc fix
* add note (not a public class)
* parametrize
* - add bleu scores integration tests
* skip test if sacrebleu is not installed
* cache heavy models/tokenizers
* some tweaks
* remove tokens that aren't used
* more purging
* simplify code
* switch to using decoder_start_token_id
* add doc
* Revert "major refactor (reuse-bart)"
This reverts commit 226dad15ca.
* decouple from bart
* remove unused code #1
* remove unused code #2
* remove unused code #3
* update instructions
* clean up
* move bleu eval to examples
* check import only once
* move data+gen script into files
* reuse via import
* take less space
* add prepare_seq2seq_batch (auto-tested)
* cleanup
* recode test to use json instead of yaml
* ignore keys not needed
* use the new -y in transformers-cli upload -y
* [xlm tok] config dict: fix str into int to match definition (#7034)
* [s2s] --eval_max_generate_length (#7018)
* Fix CI with change of name of nlp (#7054)
* nlp -> datasets
* More nlp -> datasets
* Woopsie
* More nlp -> datasets
* One last
* extending to support allen_nlp wmt models
- allow a specific checkpoint file to be passed
- more arg settings
- scripts for allen_nlp models
* sync with changes
* s/fsmt-wmt/wmt/ in model names
* s/fsmt-wmt/wmt/ in model names (p2)
* s/fsmt-wmt/wmt/ in model names (p3)
* switch to a better checkpoint
* typo
* make non-optional args such - adjust tests where possible or skip when there is no other choice
* consistency
* style
* adjust header
* cards moved (model rename)
* use best custom hparams
* update info
* remove old cards
* cleanup
* s/stas/facebook/
* update scores
* s/allen_nlp/allenai/
* url maps aren't needed
* typo
* move all the doc / build /eval generators to their own scripts
* cleanup
* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* fix indent
* duplicated line
* style
* use the correct add_start_docstrings
* oops
* resizing can't be done with the core approach, due to 2 dicts
* check that the arg is a list
* style
* style
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent 492bb6aa48
commit 1eeb206bef
@ -3,9 +3,9 @@ Transformers

State-of-the-art Natural Language Processing for Pytorch and TensorFlow 2.0.

🤗 Transformers (formerly known as `pytorch-transformers` and `pytorch-pretrained-bert`) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.

This is the documentation of our repository `transformers <https://github.com/huggingface/transformers>`_.

@ -127,7 +127,7 @@ conversion utilities for the following models:

23. `Pegasus <https://github.com/google-research/pegasus>`_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization
    <https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
24. `MBart <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`_ (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov,
    Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
25. `LXMERT <https://github.com/airsplay/lxmert>`_ (from UNC Chapel Hill) released with the paper `LXMERT: Learning
    Cross-Modality Encoder Representations from Transformers for Open-Domain Question
    Answering <https://arxiv.org/abs/1908.07490>`_ by Hao Tan and Mohit Bansal.

@ -223,6 +223,7 @@ conversion utilities for the following models:

    model_doc/dpr
    model_doc/pegasus
    model_doc/mbart
    model_doc/fsmt
    model_doc/funnel
    model_doc/lxmert
    model_doc/bertgeneration

49  docs/source/model_doc/fsmt.rst  Normal file
@ -0,0 +1,49 @@
FSMT
----------------------------------------------------

**DISCLAIMER:** If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@stas00.

Overview
~~~~~~~~~~~~~~~~~~~~~

FSMT (FairSeq MachineTranslation) models were introduced in `Facebook FAIR's WMT19 News Translation Task Submission <https://arxiv.org/abs/1907.06616>`__ by Nathan Ng, Kyra Yee, Alexei Baevski, Myle Ott, Michael Auli, Sergey Edunov.

The abstract of the paper is the following:

This paper describes Facebook FAIR's submission to the WMT19 shared news translation task. We participate in two language pairs and four language directions, English <-> German and English <-> Russian. Following our submission from last year, our baseline systems are large BPE-based transformer models trained with the Fairseq sequence modeling toolkit which rely on sampled back-translations. This year we experiment with different bitext data filtering schemes, as well as with adding filtered back-translated data. We also ensemble and fine-tune our models on domain-specific data, then decode using noisy channel model reranking. Our submissions are ranked first in all four directions of the human evaluation campaign. On En->De, our system significantly outperforms other systems as well as human translations. This system improves upon our WMT'18 submission by 4.5 BLEU points.

The original code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/wmt19>`__.

Implementation Notes
~~~~~~~~~~~~~~~~~~~~

- FSMT uses separate source and target vocabularies, which are not combined into one, and it does not share embedding tokens. Its tokenizer is very similar to `XLMTokenizer` and the main model is derived from `BartModel`.
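
The separate vocabularies are visible directly in the configuration. A minimal sketch (assuming one of the ported checkpoints, e.g. ``facebook/wmt19-en-ru``, is available)::

    from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTTokenizer

    mname = "facebook/wmt19-en-ru"
    config = FSMTConfig.from_pretrained(mname)
    # two independent vocab sizes, one per language, unlike models with a single shared vocab
    print(config.src_vocab_size, config.tgt_vocab_size)

    tokenizer = FSMTTokenizer.from_pretrained(mname)
    model = FSMTForConditionalGeneration.from_pretrained(mname)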

FSMTForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FSMTForConditionalGeneration
    :members: forward


FSMTConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FSMTConfig
    :members:


FSMTTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FSMTTokenizer
    :members:


FSMTModel
~~~~~~~~~~~~~

.. autoclass:: transformers.FSMTModel
    :members: forward
33  examples/seq2seq/test_data/fsmt/build-eval-data.py  Executable file
@ -0,0 +1,33 @@
#!/usr/bin/env python

import io
import json
import subprocess


pairs = [
    ["en", "ru"],
    ["ru", "en"],
    ["en", "de"],
    ["de", "en"],
]

n_objs = 8


def get_all_data(pairs, n_objs):
    text = {}
    for src, tgt in pairs:
        pair = f"{src}-{tgt}"
        # sacrebleu echoes the raw source and reference sides of the wmt19 test set
        cmd = f"sacrebleu -t wmt19 -l {pair} --echo src".split()
        src_lines = subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode("utf-8").splitlines()
        cmd = f"sacrebleu -t wmt19 -l {pair} --echo ref".split()
        tgt_lines = subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode("utf-8").splitlines()
        # keep only the first n_objs sentences for each direction
        text[pair] = {"src": src_lines[:n_objs], "tgt": tgt_lines[:n_objs]}
    return text


text = get_all_data(pairs, n_objs)
filename = "./fsmt_val_data.json"
with io.open(filename, "w", encoding="utf-8") as f:
    json.dump(text, f, indent=2, ensure_ascii=False)
90
examples/seq2seq/test_data/fsmt/fsmt_val_data.json
Normal file
90
examples/seq2seq/test_data/fsmt/fsmt_val_data.json
Normal file
@ -0,0 +1,90 @@
|
||||
{
|
||||
"en-ru": {
|
||||
"src": [
|
||||
"Welsh AMs worried about 'looking like muppets'",
|
||||
"There is consternation among some AMs at a suggestion their title should change to MWPs (Member of the Welsh Parliament).",
|
||||
"It has arisen because of plans to change the name of the assembly to the Welsh Parliament.",
|
||||
"AMs across the political spectrum are worried it could invite ridicule.",
|
||||
"One Labour AM said his group was concerned \"it rhymes with Twp and Pwp.\"",
|
||||
"For readers outside of Wales: In Welsh twp means daft and pwp means poo.",
|
||||
"A Plaid AM said the group as a whole was \"not happy\" and has suggested alternatives.",
|
||||
"A Welsh Conservative said his group was \"open minded\" about the name change, but noted it was a short verbal hop from MWP to Muppet."
|
||||
],
|
||||
"tgt": [
|
||||
"Члены Национальной ассамблеи Уэльса обеспокоены, что \"выглядят как куклы\"",
|
||||
"Некоторые члены Национальной ассамблеи Уэльса в ужасе от предложения о том, что их наименование должно измениться на MPW (члены Парламента Уэльса).",
|
||||
"Этот вопрос был поднят в связи с планами по переименованию ассамблеи в Парламент Уэльса.",
|
||||
"Члены Национальной ассамблеи Уэльса всего политического спектра обеспокоены, что это может породить насмешки.",
|
||||
"Один из лейбористских членов Национальной ассамблеи Уэльса сказал, что его партия обеспокоена тем, что \"это рифмуется с Twp и Pwp\".",
|
||||
"Для читателей за предлами Уэльса: по-валлийски twp означает \"глупый\", а pwp означает \"какашка\".",
|
||||
"Член Национальной ассамблеи от Плайд сказал, что эта партия в целом \"не счастлива\" и предложил альтернативы.",
|
||||
"Представитель Консервативной партии Уэльса сказал, что его партия \"открыта\" к переименованию, но отметил, что между WMP и Muppet небольшая разница в произношении."
|
||||
]
|
||||
},
|
||||
"ru-en": {
|
||||
"src": [
|
||||
"Названо число готовящихся к отправке в Донбасс новобранцев из Украины",
|
||||
"Официальный представитель Народной милиции самопровозглашенной Луганской Народной Республики (ЛНР) Андрей Марочко заявил, что зимой 2018-2019 года Украина направит в Донбасс не менее 3 тыс. новобранцев.",
|
||||
"По его словам, таким образом Киев планирует \"хоть как-то доукомплектовать подразделения\".",
|
||||
"\"Нежелание граждан Украины проходить службу в рядах ВС Украины, массовые увольнения привели к низкой укомплектованности подразделений\", - рассказал Марочко, которого цитирует \"РИА Новости\".",
|
||||
"Он также не исключил, что реальные цифры призванных в армию украинцев могут быть увеличены в случае необходимости.",
|
||||
"В 2014-2017 годах Киев начал так называемую антитеррористическую операцию (АТО), которую позже сменили на операцию объединенных сил (ООС).",
|
||||
"Предполагалось, что эта мера приведет к усилению роли украинских силовиков в урегулировании ситуации.",
|
||||
"В конце августа 2018 года ситуация в Донбассе обострилась из-за убийства главы ДНР Александра Захарченко."
|
||||
],
|
||||
"tgt": [
|
||||
"The number of new Ukrainian recruits ready to go to Donbass has become public",
|
||||
"Official representative of the peoples’ militia of the self-proclaimed Lugansk People’s Republic Andrey Marochko claimed that Ukrainian will send at least 3 thousand new recruits to Donbass in winter 2018-2019.",
|
||||
"This is how Kyiv tries “at least somehow to staff the units,” he said.",
|
||||
"“The unwillingness of Ukrainian citizens to serve in the Ukraine’s military forces, mass resignments lead to low understaffing,” said Marochko cited by RIA Novosti.",
|
||||
"Also, he doesn’t exclude that the real numbers of conscripts in the Ukrainian army can be raised is necessary.",
|
||||
"In 2014-2017, Kyiv started so-called antiterrorist operation, that ws later changed to the united forces operation.",
|
||||
"This measure was supposed to strengthen the role of the Ukrainian military in settling the situation.",
|
||||
"In the late August 2018, the situation in Donbass escalated as the DNR head Aleksandr Zakharchenko was killed."
|
||||
]
|
||||
},
|
||||
"en-de": {
|
||||
"src": [
|
||||
"Welsh AMs worried about 'looking like muppets'",
|
||||
"There is consternation among some AMs at a suggestion their title should change to MWPs (Member of the Welsh Parliament).",
|
||||
"It has arisen because of plans to change the name of the assembly to the Welsh Parliament.",
|
||||
"AMs across the political spectrum are worried it could invite ridicule.",
|
||||
"One Labour AM said his group was concerned \"it rhymes with Twp and Pwp.\"",
|
||||
"For readers outside of Wales: In Welsh twp means daft and pwp means poo.",
|
||||
"A Plaid AM said the group as a whole was \"not happy\" and has suggested alternatives.",
|
||||
"A Welsh Conservative said his group was \"open minded\" about the name change, but noted it was a short verbal hop from MWP to Muppet."
|
||||
],
|
||||
"tgt": [
|
||||
"Walisische Ageordnete sorgen sich \"wie Dödel auszusehen\"",
|
||||
"Es herrscht Bestürzung unter einigen Mitgliedern der Versammlung über einen Vorschlag, der ihren Titel zu MWPs (Mitglied der walisischen Parlament) ändern soll.",
|
||||
"Der Grund dafür waren Pläne, den Namen der Nationalversammlung in Walisisches Parlament zu ändern.",
|
||||
"Mitglieder aller Parteien der Nationalversammlung haben Bedenken, dass sie sich dadurch Spott aussetzen könnten.",
|
||||
"Ein Labour-Abgeordneter sagte, dass seine Gruppe \"sich mit Twp und Pwp reimt\".",
|
||||
"Hinweis für den Leser: „twp“ im Walisischen bedeutet „bescheuert“ und „pwp“ bedeutet „Kacke“.",
|
||||
"Ein Versammlungsmitglied von Plaid Cymru sagte, die Gruppe als Ganzes sei \"nicht glücklich\" und hat Alternativen vorgeschlagen.",
|
||||
"Ein walisischer Konservativer sagte, seine Gruppe wäre „offen“ für eine Namensänderung, wies aber darauf hin, dass es von „MWP“ (Mitglied des Walisischen Parlaments) nur ein kurzer verbaler Sprung zu „Muppet“ ist."
|
||||
]
|
||||
},
|
||||
"de-en": {
|
||||
"src": [
|
||||
"Schöne Münchnerin 2018: Schöne Münchnerin 2018 in Hvar: Neun Dates",
|
||||
"Von az, aktualisiert am 04.05.2018 um 11:11",
|
||||
"Ja, sie will...",
|
||||
"\"Schöne Münchnerin\" 2018 werden!",
|
||||
"Am Nachmittag wartet erneut eine Überraschung auf unsere Kandidatinnen: sie werden das romantische Candlelight-Shooting vor der MY SOLARIS nicht alleine bestreiten, sondern an der Seite von Male-Model Fabian!",
|
||||
"Hvar - Flirten, kokettieren, verführen - keine einfachen Aufgaben für unsere Mädchen.",
|
||||
"Insbesondere dann, wenn in Deutschland ein Freund wartet.",
|
||||
"Dennoch liefern die neun \"Schöne Münchnerin\"-Kandidatinnen beim Shooting mit People-Fotograf Tuan ab und trotzen Wind, Gischt und Regen wie echte Profis."
|
||||
],
|
||||
"tgt": [
|
||||
"The Beauty of Munich 2018: the Beauty of Munich 2018 in Hvar: Nine dates",
|
||||
"From A-Z, updated on 04/05/2018 at 11:11",
|
||||
"Yes, she wants to...",
|
||||
"to become \"The Beauty of Munich\" in 2018!",
|
||||
"In the afternoon there is another surprise waiting for our contestants: they will be competing for the romantic candlelight photo shoot at MY SOLARIS not alone, but together with a male-model Fabian!",
|
||||
"Hvar with its flirting, coquetting, and seduction is not an easy task for our girls.",
|
||||
"Especially when there is a boyfriend waiting in Germany.",
|
||||
"Despite dealing with wind, sprays and rain, the nine contestants of \"The Beauty of Munich\" behaved like real professionals at the photo shoot with People-photographer Tuan."
|
||||
]
|
||||
}
|
||||
}
|
77  examples/seq2seq/test_fsmt_bleu_score.py  Normal file
@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import unittest


try:
    from .utils import calculate_bleu
except ImportError:
    from utils import calculate_bleu

import json

from parameterized import parameterized
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device


filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
with io.open(filename, "r", encoding="utf-8") as f:
    bleu_data = json.load(f)


@require_torch
class ModelEvalTester(unittest.TestCase):
    def get_tokenizer(self, mname):
        return FSMTTokenizer.from_pretrained(mname)

    def get_model(self, mname):
        model = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device)
        if torch_device == "cuda":
            model.half()
        return model

    @parameterized.expand(
        [
            ["en-ru", 26.0],
            ["ru-en", 22.0],
            ["en-de", 22.0],
            ["de-en", 29.0],
        ]
    )
    @slow
    def test_bleu_scores(self, pair, min_bleu_score):
        # note: this test is not testing the best performance since it only evals a small batch
        # but it should be enough to detect a regression in the output quality
        mname = f"facebook/wmt19-{pair}"
        tokenizer = self.get_tokenizer(mname)
        model = self.get_model(mname)

        src_sentences = bleu_data[pair]["src"]
        tgt_sentences = bleu_data[pair]["tgt"]

        batch = tokenizer(src_sentences, return_tensors="pt", truncation=True, padding="longest").to(torch_device)
        outputs = model.generate(
            input_ids=batch.input_ids,
            num_beams=8,
        )
        decoded_sentences = tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        scores = calculate_bleu(decoded_sentences, tgt_sentences)
        print(scores)
        self.assertGreaterEqual(scores["bleu"], min_bleu_score)
|
94  model_cards/facebook/wmt19-de-en/README.md  Normal file
@ -0,0 +1,94 @@

---

<!-- This file has been auto-generated by src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py - DO NOT EDIT or your changes will be lost -->

language: de, en
thumbnail:
tags:
- translation
- wmt19
license: Apache 2.0
datasets:
- http://www.statmt.org/wmt19/ ([test-set](http://matrix.statmt.org/test_sets/newstest2019.tgz?1556572561))
metrics:
- http://www.statmt.org/wmt19/metrics-task.html
---

# FSMT

## Model description

This is a ported version of [fairseq wmt19 transformer](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md) for de-en.

For more details, please see [Facebook FAIR's WMT19 News Translation Task Submission](https://arxiv.org/abs/1907.06616).

The abbreviation FSMT stands for FairSeqMachineTranslation.

All four models are available:

* [wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru)
* [wmt19-ru-en](https://huggingface.co/facebook/wmt19-ru-en)
* [wmt19-en-de](https://huggingface.co/facebook/wmt19-en-de)
* [wmt19-de-en](https://huggingface.co/facebook/wmt19-de-en)

## Intended uses & limitations

#### How to use

```python
from transformers.tokenization_fsmt import FSMTTokenizer
from transformers.modeling_fsmt import FSMTForConditionalGeneration
mname = "facebook/wmt19-de-en"
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTForConditionalGeneration.from_pretrained(mname)

input = "Maschinelles Lernen ist großartig, oder?"
input_ids = tokenizer.encode(input, return_tensors="pt")
outputs = model.generate(input_ids)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded) # Machine learning is great, isn't it?

```
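
Generation defaults such as `num_beams` (5 for this model) come from the model config and can be overridden per call. A small sketch reusing the objects created above (a wider beam is slower but may score slightly higher):

```python
# decode again with a wider beam than the configured default of 5
outputs = model.generate(input_ids, num_beams=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```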

#### Limitations and bias

- The original (and this ported model) doesn't seem to handle well inputs with repeated sub-phrases, [content gets truncated](https://discuss.huggingface.co/t/issues-with-translating-inputs-containing-repeated-phrases/981)

## Training data

Pretrained weights were left identical to the original model released by fairseq. For more details, please see the [paper](https://arxiv.org/abs/1907.06616).

## Eval results

pair   | fairseq | transformers
-------|---------|----------
de-en  | [42.3](http://matrix.statmt.org/matrix/output/1902?run_id=6750) | 41.35

The score is slightly below the score reported by `fairseq`, since `transformers` currently doesn't support:
- model ensemble, therefore the best performing checkpoint was ported (``model4.pt``).
- re-ranking

The score was calculated using this code:

```bash
git clone https://github.com/huggingface/transformers
cd transformers
export PAIR=de-en
export DATA_DIR=data/$PAIR
export SAVE_DIR=data/$PAIR
export BS=8
export NUM_BEAMS=15
mkdir -p $DATA_DIR
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
echo $PAIR
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
```
note: fairseq reports using a beam of 50, so you should get a slightly higher score if re-run with `--num_beams 50`.


## TODO

- port model ensemble (fairseq uses 4 model checkpoints)

94
model_cards/facebook/wmt19-en-de/README.md
Normal file
94
model_cards/facebook/wmt19-en-de/README.md
Normal file
@ -0,0 +1,94 @@
|
||||
|
||||
---
|
||||
|
||||
<!-- This file has been auto-generated by src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py - DO NOT EDIT or your changes will be lost -->
|
||||
|
||||
language: en, de
|
||||
thumbnail:
|
||||
tags:
|
||||
- translation
|
||||
- wmt19
|
||||
license: Apache 2.0
|
||||
datasets:
|
||||
- http://www.statmt.org/wmt19/ ([test-set](http://matrix.statmt.org/test_sets/newstest2019.tgz?1556572561))
|
||||
metrics:
|
||||
- http://www.statmt.org/wmt19/metrics-task.html
|
||||
---
|
||||
|
||||
# FSMT
|
||||
|
||||
## Model description
|
||||
|
||||
This is a ported version of [fairseq wmt19 transformer](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md) for en-de.
|
||||
|
||||
For more details, please see [Facebook FAIR's WMT19 News Translation Task Submission](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
The abbreviation FSMT stands for FairSeqMachineTranslation.
|
||||
|
||||
All four models are available:
|
||||
|
||||
* [wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru)
|
||||
* [wmt19-ru-en](https://huggingface.co/facebook/wmt19-ru-en)
|
||||
* [wmt19-en-de](https://huggingface.co/facebook/wmt19-en-de)
|
||||
* [wmt19-de-en](https://huggingface.co/facebook/wmt19-de-en)
|
||||
|
||||
## Intended uses & limitations
|
||||
|
||||
#### How to use
|
||||
|
||||
```python
|
||||
from transformers.tokenization_fsmt import FSMTTokenizer
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
mname = "facebook/wmt19-en-de"
|
||||
tokenizer = FSMTTokenizer.from_pretrained(mname)
|
||||
model = FSMTForConditionalGeneration.from_pretrained(mname)
|
||||
|
||||
input = "Machine learning is great, isn't it?"
|
||||
input_ids = tokenizer.encode(input, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(decoded) # Maschinelles Lernen ist großartig, oder?
|
||||
|
||||
```
|
||||
|
||||
#### Limitations and bias
|
||||
|
||||
- The original (and this ported model) doesn't seem to handle well inputs with repeated sub-phrases, [content gets truncated](https://discuss.huggingface.co/t/issues-with-translating-inputs-containing-repeated-phrases/981)
|
||||
|
||||
## Training data
|
||||
|
||||
Pretrained weights were left identical to the original model released by fairseq. For more details, please see the [paper](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
## Eval results
|
||||
|
||||
pair | fairseq | transformers
|
||||
-------|---------|----------
|
||||
en-de | [43.1](http://matrix.statmt.org/matrix/output/1909?run_id=6862) | 42.83
|
||||
|
||||
The score is slightly below the score reported by `fairseq`, since `transformers` currently doesn't support:
|
||||
- model ensemble, therefore the best performing checkpoint was ported (``model4.pt``).
|
||||
- re-ranking
|
||||
|
||||
The score was calculated using this code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
export PAIR=en-de
|
||||
export DATA_DIR=data/$PAIR
|
||||
export SAVE_DIR=data/$PAIR
|
||||
export BS=8
|
||||
export NUM_BEAMS=15
|
||||
mkdir -p $DATA_DIR
|
||||
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||||
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||||
echo $PAIR
|
||||
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||||
```
|
||||
note: fairseq reports using a beam of 50, so you should get a slightly higher score if re-run with `--num_beams 50`.
|
||||
|
||||
|
||||
## TODO
|
||||
|
||||
- port model ensemble (fairseq uses 4 model checkpoints)
|
||||
|
94
model_cards/facebook/wmt19-en-ru/README.md
Normal file
94
model_cards/facebook/wmt19-en-ru/README.md
Normal file
@ -0,0 +1,94 @@
|
||||
|
||||
---
|
||||
|
||||
<!-- This file has been auto-generated by src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py - DO NOT EDIT or your changes will be lost -->
|
||||
|
||||
language: en, ru
|
||||
thumbnail:
|
||||
tags:
|
||||
- translation
|
||||
- wmt19
|
||||
license: Apache 2.0
|
||||
datasets:
|
||||
- http://www.statmt.org/wmt19/ ([test-set](http://matrix.statmt.org/test_sets/newstest2019.tgz?1556572561))
|
||||
metrics:
|
||||
- http://www.statmt.org/wmt19/metrics-task.html
|
||||
---
|
||||
|
||||
# FSMT
|
||||
|
||||
## Model description
|
||||
|
||||
This is a ported version of [fairseq wmt19 transformer](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md) for en-ru.
|
||||
|
||||
For more details, please see [Facebook FAIR's WMT19 News Translation Task Submission](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
The abbreviation FSMT stands for FairSeqMachineTranslation.
|
||||
|
||||
All four models are available:
|
||||
|
||||
* [wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru)
|
||||
* [wmt19-ru-en](https://huggingface.co/facebook/wmt19-ru-en)
|
||||
* [wmt19-en-de](https://huggingface.co/facebook/wmt19-en-de)
|
||||
* [wmt19-de-en](https://huggingface.co/facebook/wmt19-de-en)
|
||||
|
||||
## Intended uses & limitations
|
||||
|
||||
#### How to use
|
||||
|
||||
```python
|
||||
from transformers.tokenization_fsmt import FSMTTokenizer
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
mname = "facebook/wmt19-en-ru"
|
||||
tokenizer = FSMTTokenizer.from_pretrained(mname)
|
||||
model = FSMTForConditionalGeneration.from_pretrained(mname)
|
||||
|
||||
input = "Machine learning is great, isn't it?"
|
||||
input_ids = tokenizer.encode(input, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(decoded) # Машинное обучение - это здорово, не так ли?
|
||||
|
||||
```
|
||||
|
||||
#### Limitations and bias
|
||||
|
||||
- The original (and this ported model) doesn't seem to handle well inputs with repeated sub-phrases, [content gets truncated](https://discuss.huggingface.co/t/issues-with-translating-inputs-containing-repeated-phrases/981)
|
||||
|
||||
## Training data
|
||||
|
||||
Pretrained weights were left identical to the original model released by fairseq. For more details, please see the [paper](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
## Eval results
|
||||
|
||||
pair | fairseq | transformers
|
||||
-------|---------|----------
|
||||
en-ru | [36.4](http://matrix.statmt.org/matrix/output/1914?run_id=6724) | 33.47
|
||||
|
||||
The score is slightly below the score reported by `fairseq`, since `transformers` currently doesn't support:
|
||||
- model ensemble, therefore the best performing checkpoint was ported (``model4.pt``).
|
||||
- re-ranking
|
||||
|
||||
The score was calculated using this code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
export PAIR=en-ru
|
||||
export DATA_DIR=data/$PAIR
|
||||
export SAVE_DIR=data/$PAIR
|
||||
export BS=8
|
||||
export NUM_BEAMS=15
|
||||
mkdir -p $DATA_DIR
|
||||
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||||
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||||
echo $PAIR
|
||||
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||||
```
|
||||
note: fairseq reports using a beam of 50, so you should get a slightly higher score if re-run with `--num_beams 50`.
|
||||
|
||||
|
||||
## TODO
|
||||
|
||||
- port model ensemble (fairseq uses 4 model checkpoints)
|
||||
|
94
model_cards/facebook/wmt19-ru-en/README.md
Normal file
94
model_cards/facebook/wmt19-ru-en/README.md
Normal file
@ -0,0 +1,94 @@
|
||||
|
||||
---
|
||||
|
||||
<!-- This file has been auto-generated by src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py - DO NOT EDIT or your changes will be lost -->
|
||||
|
||||
language: ru, en
|
||||
thumbnail:
|
||||
tags:
|
||||
- translation
|
||||
- wmt19
|
||||
license: Apache 2.0
|
||||
datasets:
|
||||
- http://www.statmt.org/wmt19/ ([test-set](http://matrix.statmt.org/test_sets/newstest2019.tgz?1556572561))
|
||||
metrics:
|
||||
- http://www.statmt.org/wmt19/metrics-task.html
|
||||
---
|
||||
|
||||
# FSMT
|
||||
|
||||
## Model description
|
||||
|
||||
This is a ported version of [fairseq wmt19 transformer](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md) for ru-en.
|
||||
|
||||
For more details, please see [Facebook FAIR's WMT19 News Translation Task Submission](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
The abbreviation FSMT stands for FairSeqMachineTranslation.
|
||||
|
||||
All four models are available:
|
||||
|
||||
* [wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru)
|
||||
* [wmt19-ru-en](https://huggingface.co/facebook/wmt19-ru-en)
|
||||
* [wmt19-en-de](https://huggingface.co/facebook/wmt19-en-de)
|
||||
* [wmt19-de-en](https://huggingface.co/facebook/wmt19-de-en)
|
||||
|
||||
## Intended uses & limitations
|
||||
|
||||
#### How to use
|
||||
|
||||
```python
|
||||
from transformers.tokenization_fsmt import FSMTTokenizer
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
mname = "facebook/wmt19-ru-en"
|
||||
tokenizer = FSMTTokenizer.from_pretrained(mname)
|
||||
model = FSMTForConditionalGeneration.from_pretrained(mname)
|
||||
|
||||
input = "Машинное обучение - это здорово, не так ли?"
|
||||
input_ids = tokenizer.encode(input, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(decoded) # Machine learning is great, isn't it?
|
||||
|
||||
```
|
||||
|
||||
#### Limitations and bias
|
||||
|
||||
- The original (and this ported model) doesn't seem to handle well inputs with repeated sub-phrases, [content gets truncated](https://discuss.huggingface.co/t/issues-with-translating-inputs-containing-repeated-phrases/981)
|
||||
|
||||
## Training data
|
||||
|
||||
Pretrained weights were left identical to the original model released by fairseq. For more details, please see the [paper](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
## Eval results
|
||||
|
||||
pair | fairseq | transformers
|
||||
-------|---------|----------
|
||||
ru-en | [41.3](http://matrix.statmt.org/matrix/output/1907?run_id=6937) | 39.20
|
||||
|
||||
The score is slightly below the score reported by `fairseq`, since `transformers` currently doesn't support:
|
||||
- model ensemble, therefore the best performing checkpoint was ported (``model4.pt``).
|
||||
- re-ranking
|
||||
|
||||
The score was calculated using this code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
export PAIR=ru-en
|
||||
export DATA_DIR=data/$PAIR
|
||||
export SAVE_DIR=data/$PAIR
|
||||
export BS=8
|
||||
export NUM_BEAMS=15
|
||||
mkdir -p $DATA_DIR
|
||||
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||||
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||||
echo $PAIR
|
||||
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||||
```
|
||||
note: fairseq reports using a beam of 50, so you should get a slightly higher score if re-run with `--num_beams 50`.
|
||||
|
||||
|
||||
## TODO
|
||||
|
||||
- port model ensemble (fairseq uses 4 model checkpoints)
|
||||
|
@ -30,6 +30,7 @@ from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
|
||||
from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
@ -158,6 +159,7 @@ from .tokenization_dpr import (
|
||||
)
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_fsmt import FSMTTokenizer
|
||||
from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
@ -338,6 +340,7 @@ if is_torch_available():
|
||||
FlaubertModel,
|
||||
FlaubertWithLMHeadModel,
|
||||
)
|
||||
from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel
|
||||
from .modeling_funnel import (
|
||||
FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
FunnelBaseModel,
|
||||
|
@ -27,6 +27,7 @@ from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
|
||||
from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
@ -66,6 +67,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -94,6 +96,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("longformer", LongformerConfig),
|
||||
("roberta", RobertaConfig),
|
||||
("flaubert", FlaubertConfig),
|
||||
("fsmt", FSMTConfig),
|
||||
("bert", BertConfig),
|
||||
("openai-gpt", OpenAIGPTConfig),
|
||||
("gpt2", GPT2Config),
|
||||
@ -126,6 +129,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("longformer", "Longformer"),
|
||||
("roberta", "RoBERTa"),
|
||||
("flaubert", "FlauBERT"),
|
||||
("fsmt", "FairSeq Machine-Translation"),
|
||||
("bert", "BERT"),
|
||||
("openai-gpt", "OpenAI GPT"),
|
||||
("gpt2", "OpenAI GPT-2"),
|
||||
|
223
src/transformers/configuration_fsmt.py
Normal file
223
src/transformers/configuration_fsmt.py
Normal file
@ -0,0 +1,223 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" FSMT configuration """
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
FSMT_CONFIG_ARGS_DOC = r"""
|
||||
Args:
|
||||
langs (:obj:`List[str]`):
|
||||
source language, target language (e.g. ['en', 'ru'])
|
||||
src_vocab_size (:obj:`int`):
|
||||
defines the different tokens that can be represented by `inputs_ids` passed to the forward
|
||||
method in the encoder.
|
||||
tgt_vocab_size (:obj:`int`):
|
||||
defines the different tokens that can be represented by `inputs_ids` passed to the forward
|
||||
method in the decoder.
|
||||
d_model (:obj:`int`, `optional`, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of encoder layers, 16 for pegasus, 6 for bart-base and marian
|
||||
decoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of decoder layers, 16 for pegasus, 6 for bart-base and marian
|
||||
encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to "relu"):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
||||
dropout (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Scale embeddings by dividing by sqrt(d_model).
|
||||
bos_token_id (:obj:`int`, `optional`, defaults to 0)
|
||||
Beginning of stream token id.
|
||||
pad_token_id (:obj:`int`, `optional`, defaults to 1)
|
||||
Padding token id.
|
||||
eos_token_id (:obj:`int`, `optional`, defaults to 2)
|
||||
End of stream token id.
|
||||
decoder_start_token_id (:obj:`int`, `optional`):
|
||||
This model starts decoding with `eos_token_id`
|
||||
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Google "layerdrop arxiv", as its not explainable in one line.
|
||||
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Google "layerdrop arxiv", as its not explainable in one line.
|
||||
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether this is an encoder/decoder model.
|
||||
tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to tie input and output embeddings.
|
||||
num_beams (:obj:`int`, `optional`, defaults to 5)
|
||||
Number of beams for beam search that will be used by default in the :obj:`generate` method
|
||||
of the model. 1 means no beam search.
|
||||
length_penalty (:obj:`float`, `optional`, defaults to 1)
|
||||
Exponential penalty to the length that will be used by default in the :obj:`generate` method
|
||||
of the model.
|
||||
early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`)
|
||||
Flag that will be used by default in the :obj:`generate` method of the model. Whether to stop
|
||||
the beam search when at least ``num_beams`` sentences are finished per batch or not.
|
||||
"""
|
||||
|
||||
|
||||
class DecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
Configuration class for FSMT's decoder specific things.
|
||||
note: this is a private helper class
|
||||
"""
|
||||
model_type = "fsmt_decoder"
|
||||
|
||||
def __init__(self, vocab_size=0, bos_token_id=0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
|
||||
@add_start_docstrings_to_callable(FSMT_CONFIG_ARGS_DOC)
|
||||
class FSMTConfig(PretrainedConfig):
|
||||
r"""
|
||||
Configuration class for FSMT.
|
||||
"""
|
||||
model_type = "fsmt"
|
||||
|
||||
# update the defaults from config file
|
||||
def __init__(
|
||||
self,
|
||||
langs,
|
||||
src_vocab_size,
|
||||
tgt_vocab_size,
|
||||
activation_function="relu",
|
||||
d_model=1024,
|
||||
max_length=200,
|
||||
max_position_embeddings=1024,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_layers=12,
|
||||
encoder_attention_heads=16,
|
||||
encoder_layerdrop=0.0,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_layers=12,
|
||||
decoder_attention_heads=16,
|
||||
decoder_layerdrop=0.0,
|
||||
attention_dropout=0.0,
|
||||
dropout=0.1,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
decoder_start_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
scale_embedding=True,
|
||||
tie_word_embeddings=False,
|
||||
num_beams=5,
|
||||
length_penalty=1.0,
|
||||
early_stopping=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
:class:`~transformers.FSMTConfig` is the configuration class for `FSMTModel`.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import FSMTConfig, FSMTModel
|
||||
|
||||
>>> config = FSMTConfig.from_pretrained('facebook/wmt19-en-ru')
|
||||
>>> model = FSMTModel(config)
|
||||
|
||||
"""
|
||||
if "hidden_size" in common_kwargs:
|
||||
raise ValueError("hidden size is called d_model")
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**common_kwargs,
|
||||
)
|
||||
self.langs = langs
|
||||
self.src_vocab_size = src_vocab_size
|
||||
self.tgt_vocab_size = tgt_vocab_size
|
||||
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
|
||||
self.max_length = max_length
|
||||
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = self.num_hidden_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
self.activation_function = activation_function
|
||||
|
||||
self.num_beams = num_beams
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
|
||||
self.decoder = DecoderConfig(vocab_size=tgt_vocab_size, bos_token_id=eos_token_id)
|
||||
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.dropout = dropout
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["decoder"] = self.decoder.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
271
src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
Executable file
271
src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
Executable file
@ -0,0 +1,271 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: if you intend to run this script make sure you look under scripts/fsmt/
|
||||
# to locate the appropriate script to do the work correctly. There is a set of scripts to:
|
||||
# - download and prepare data and run the conversion script
|
||||
# - perform eval to get the best hparam into the config
|
||||
# - generate model_cards - useful if you have multiple models from the same paper
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from os.path import basename, dirname
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from fairseq import hub_utils
|
||||
from fairseq.data.dictionary import Dictionary
|
||||
|
||||
from transformers import WEIGHTS_NAME
|
||||
from transformers.configuration_fsmt import FSMTConfig
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
from transformers.tokenization_fsmt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
# based on the results of a search on a range of `num_beams`, `length_penalty` and `early_stopping`
|
||||
# values against wmt19 test data to obtain the best BLEU scores, we will use the following defaults:
|
||||
#
|
||||
# * `num_beams`: 5 (higher scores better, but requires more memory/is slower, can be adjusted by users)
|
||||
# * `early_stopping`: `False` consistently scored better
|
||||
# * `length_penalty` varied, so will assign the best one depending on the model
|
||||
best_score_hparams = {
|
||||
# fairseq:
|
||||
"wmt19-ru-en": {"length_penalty": 1.1},
|
||||
"wmt19-en-ru": {"length_penalty": 1.15},
|
||||
"wmt19-en-de": {"length_penalty": 1.0},
|
||||
"wmt19-de-en": {"length_penalty": 1.1},
|
||||
# allenai:
|
||||
"wmt16-en-de-dist-12-1": {"length_penalty": 0.6},
|
||||
"wmt16-en-de-dist-6-1": {"length_penalty": 0.6},
|
||||
"wmt16-en-de-12-1": {"length_penalty": 0.8},
|
||||
"wmt19-de-en-6-6-base": {"length_penalty": 0.6},
|
||||
"wmt19-de-en-6-6-big": {"length_penalty": 0.6},
|
||||
}
|
||||
|
||||
# this remaps the different models to their organization names
|
||||
org_names = {}
|
||||
for m in ["wmt19-ru-en", "wmt19-en-ru", "wmt19-en-de", "wmt19-de-en"]:
|
||||
org_names[m] = "facebook"
|
||||
for m in [
|
||||
"wmt16-en-de-dist-12-1",
|
||||
"wmt16-en-de-dist-6-1",
|
||||
"wmt16-en-de-12-1",
|
||||
"wmt19-de-en-6-6-base",
|
||||
"wmt19-de-en-6-6-big",
|
||||
]:
|
||||
org_names[m] = "allenai"
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder_path):
|
||||
|
||||
# prep
|
||||
assert os.path.exists(fsmt_checkpoint_path)
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = basename(fsmt_checkpoint_path)
|
||||
fsmt_folder_path = dirname(fsmt_checkpoint_path)
|
||||
|
||||
cls = fairseq.model_parallel.models.transformer.ModelParallelTransformerModel
|
||||
models = cls.hub_models()
|
||||
kwargs = {"bpe": "fastbpe", "tokenizer": "moses"}
|
||||
data_name_or_path = "."
|
||||
# note: since the model dump is old, fairseq has upgraded its model some
|
||||
# time later, and it does a whole lot of rewrites and splits on the saved
|
||||
# weights, therefore we can't use torch.load() directly on the model file.
|
||||
# see: upgrade_state_dict(state_dict) in fairseq_model.py
|
||||
print(f"using checkpoint {checkpoint_file}")
|
||||
chkpt = hub_utils.from_pretrained(
|
||||
fsmt_folder_path, checkpoint_file, data_name_or_path, archive_map=models, **kwargs
|
||||
)
|
||||
|
||||
args = dict(vars(chkpt["args"]))
|
||||
|
||||
src_lang = args["source_lang"]
|
||||
tgt_lang = args["target_lang"]
|
||||
|
||||
data_root = dirname(pytorch_dump_folder_path)
|
||||
model_dir = basename(pytorch_dump_folder_path)
|
||||
|
||||
# dicts
|
||||
src_dict_file = os.path.join(fsmt_folder_path, f"dict.{src_lang}.txt")
|
||||
tgt_dict_file = os.path.join(fsmt_folder_path, f"dict.{tgt_lang}.txt")
|
||||
|
||||
src_dict = Dictionary.load(src_dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-src.json")
|
||||
print(f"Generating {src_vocab_file}")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
tgt_dict = Dictionary.load(tgt_dict_file)
|
||||
tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
|
||||
tgt_vocab_size = len(tgt_vocab)
|
||||
tgt_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-tgt.json")
|
||||
print(f"Generating {tgt_vocab_file}")
|
||||
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tgt_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
fsmt_merges_file = os.path.join(fsmt_folder_path, "bpecodes")
|
||||
with open(fsmt_merges_file, encoding="utf-8") as fin:
|
||||
merges = fin.read()
|
||||
merges = re.sub(r" \d+$", "", merges, 0, re.M) # remove frequency number
|
||||
print(f"Generating {merges_file}")
|
||||
with open(merges_file, "w", encoding="utf-8") as fout:
|
||||
fout.write(merges)
|
||||
|
||||
# model config
|
||||
fsmt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
# validate bpe/tokenizer config, as currently it's hardcoded to moses+fastbpe -
|
||||
# may have to modify the tokenizer if a different type is used by a future model
|
||||
assert args["bpe"] == "fastbpe", f"need to extend tokenizer to support bpe={args['bpe']}"
|
||||
assert args["tokenizer"] == "moses", f"need to extend tokenizer to support bpe={args['tokenizer']}"
|
||||
|
||||
model_conf = {
|
||||
"architectures": ["FSMTForConditionalGeneration"],
|
||||
"model_type": "fsmt",
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"activation_function": "relu",
|
||||
"attention_dropout": args["attention_dropout"],
|
||||
"d_model": args["decoder_embed_dim"],
|
||||
"dropout": args["dropout"],
|
||||
"init_std": 0.02,
|
||||
"max_position_embeddings": args["max_source_positions"],
|
||||
"num_hidden_layers": args["encoder_layers"],
|
||||
"src_vocab_size": src_vocab_size,
|
||||
"tgt_vocab_size": tgt_vocab_size,
|
||||
"langs": [src_lang, tgt_lang],
|
||||
"encoder_attention_heads": args["encoder_attention_heads"],
|
||||
"encoder_ffn_dim": args["encoder_ffn_embed_dim"],
|
||||
"encoder_layerdrop": args["encoder_layerdrop"],
|
||||
"encoder_layers": args["encoder_layers"],
|
||||
"decoder_attention_heads": args["decoder_attention_heads"],
|
||||
"decoder_ffn_dim": args["decoder_ffn_embed_dim"],
|
||||
"decoder_layerdrop": args["decoder_layerdrop"],
|
||||
"decoder_layers": args["decoder_layers"],
|
||||
"bos_token_id": 0,
|
||||
"pad_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"is_encoder_decoder": True,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_all_embeddings"],
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
model_conf["num_beams"] = 5
|
||||
model_conf["early_stopping"] = False
|
||||
if model_dir in best_score_hparams and "length_penalty" in best_score_hparams[model_dir]:
|
||||
model_conf["length_penalty"] = best_score_hparams[model_dir]["length_penalty"]
|
||||
else:
|
||||
model_conf["length_penalty"] = 1.0
|
||||
|
||||
print(f"Generating {fsmt_model_config_file}")
|
||||
with open(fsmt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
fsmt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"langs": [src_lang, tgt_lang],
|
||||
"model_max_length": 1024,
|
||||
}
|
||||
|
||||
print(f"Generating {fsmt_tokenizer_config_file}")
|
||||
with open(fsmt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model = chkpt["models"][0]
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
# rename keys to start with 'model.'
|
||||
model_state_dict = OrderedDict(("model." + k, v) for k, v in model_state_dict.items())
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"model.model",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"model.encoder_embed_tokens.weight",
|
||||
"model.decoder_embed_tokens.weight",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
config = FSMTConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = FSMTForConditionalGeneration(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict, strict=False)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
print("\nLast step is to upload the files to s3")
|
||||
print(f"cd {data_root}")
|
||||
print(f"transformers-cli upload {model_dir}")
|
||||
print(
|
||||
"Note: CDN caches files for up to 24h, so either use a local model path "
|
||||
"or use `from_pretrained(mname, use_cdn=False)` to use the non-cached version."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--fsmt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts, bpecodes, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_fsmt_checkpoint_to_pytorch(args.fsmt_checkpoint_path, args.pytorch_dump_folder_path)
|
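A minimal sanity-check sketch for the conversion output: load the dump folder back with the ported classes. The folder name and the sample sentence below are only illustrative and are not part of the script:

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

dump_dir = "data/wmt19-en-ru"  # hypothetical: whatever was passed as --pytorch_dump_folder_path
tokenizer = FSMTTokenizer.from_pretrained(dump_dir)
model = FSMTForConditionalGeneration.from_pretrained(dump_dir)

input_ids = tokenizer.encode("Machine learning is great, isn't it?", return_tensors="pt")
outputs = model.generate(input_ids, num_beams=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))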
@@ -30,6 +30,7 @@ from .configuration_auto import (
|
||||
ElectraConfig,
|
||||
EncoderDecoderConfig,
|
||||
FlaubertConfig,
|
||||
FSMTConfig,
|
||||
FunnelConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
@@ -113,6 +114,7 @@ from .modeling_flaubert import (
|
||||
FlaubertModel,
|
||||
FlaubertWithLMHeadModel,
|
||||
)
|
||||
from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
|
||||
from .modeling_funnel import (
|
||||
FunnelForMaskedLM,
|
||||
FunnelForMultipleChoice,
|
||||
@@ -211,6 +213,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(TransfoXLConfig, TransfoXLModel),
|
||||
(XLNetConfig, XLNetModel),
|
||||
(FlaubertConfig, FlaubertModel),
|
||||
(FSMTConfig, FSMTModel),
|
||||
(XLMConfig, XLMModel),
|
||||
(CTRLConfig, CTRLModel),
|
||||
(ElectraConfig, ElectraModel),
|
||||
@@ -230,6 +233,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(CamembertConfig, CamembertForMaskedLM),
|
||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||
(BartConfig, BartForConditionalGeneration),
|
||||
(FSMTConfig, FSMTForConditionalGeneration),
|
||||
(LongformerConfig, LongformerForMaskedLM),
|
||||
(RobertaConfig, RobertaForMaskedLM),
|
||||
(BertConfig, BertForPreTraining),
|
||||
@@ -319,6 +323,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
(MarianConfig, MarianMTModel),
|
||||
(MBartConfig, MBartForConditionalGeneration),
|
||||
(BartConfig, BartForConditionalGeneration),
|
||||
(FSMTConfig, FSMTForConditionalGeneration),
|
||||
(EncoderDecoderConfig, EncoderDecoderModel),
|
||||
]
|
||||
)
|
||||
|
src/transformers/modeling_fsmt.py (new file, 1212 lines; diff suppressed because it is too large)
@@ -29,6 +29,7 @@ from .configuration_auto import (
|
||||
ElectraConfig,
|
||||
EncoderDecoderConfig,
|
||||
FlaubertConfig,
|
||||
FSMTConfig,
|
||||
FunnelConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
@@ -59,6 +60,7 @@ from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_fsmt import FSMTTokenizer
|
||||
from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
@@ -109,6 +111,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, (FlaubertTokenizer, None)),
|
||||
(XLMConfig, (XLMTokenizer, None)),
|
||||
(CTRLConfig, (CTRLTokenizer, None)),
|
||||
(FSMTConfig, (FSMTTokenizer, None)),
|
||||
(BertGenerationConfig, (BertGenerationTokenizer, None)),
|
||||
]
|
||||
)
|
||||
|
src/transformers/tokenization_fsmt.py (new file, 535 lines)
@@ -0,0 +1,535 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for FSMT."""
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import sacremoses as sm
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"src_vocab_file": "vocab-src.json",
|
||||
"tgt_vocab_file": "vocab-tgt.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {}
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
||||
PRETRAINED_INIT_CONFIGURATION = {}
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
word is represented as tuple of symbols (symbols being variable-length strings)
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def replace_unicode_punct(text):
|
||||
"""
|
||||
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
|
||||
"""
|
||||
text = text.replace(",", ",")
|
||||
text = re.sub(r"。\s*", ". ", text)
|
||||
text = text.replace("、", ",")
|
||||
text = text.replace("”", '"')
|
||||
text = text.replace("“", '"')
|
||||
text = text.replace("∶", ":")
|
||||
text = text.replace(":", ":")
|
||||
text = text.replace("?", "?")
|
||||
text = text.replace("《", '"')
|
||||
text = text.replace("》", '"')
|
||||
text = text.replace(")", ")")
|
||||
text = text.replace("!", "!")
|
||||
text = text.replace("(", "(")
|
||||
text = text.replace(";", ";")
|
||||
text = text.replace("1", "1")
|
||||
text = text.replace("」", '"')
|
||||
text = text.replace("「", '"')
|
||||
text = text.replace("0", "0")
|
||||
text = text.replace("3", "3")
|
||||
text = text.replace("2", "2")
|
||||
text = text.replace("5", "5")
|
||||
text = text.replace("6", "6")
|
||||
text = text.replace("9", "9")
|
||||
text = text.replace("7", "7")
|
||||
text = text.replace("8", "8")
|
||||
text = text.replace("4", "4")
|
||||
text = re.sub(r".\s*", ". ", text)
|
||||
text = text.replace("~", "~")
|
||||
text = text.replace("’", "'")
|
||||
text = text.replace("…", "...")
|
||||
text = text.replace("━", "-")
|
||||
text = text.replace("〈", "<")
|
||||
text = text.replace("〉", ">")
|
||||
text = text.replace("【", "[")
|
||||
text = text.replace("】", "]")
|
||||
text = text.replace("%", "%")
|
||||
return text
|
||||
|
||||
|
||||
def remove_non_printing_char(text):
|
||||
"""
|
||||
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
|
||||
"""
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
# Porting notes:
|
||||
# this one is modeled after XLMTokenizer
|
||||
#
|
||||
# added:
|
||||
# - src_vocab_file,
|
||||
# - tgt_vocab_file,
|
||||
# - langs,
|
||||
|
||||
|
||||
class FSMTTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
BPE tokenizer for FSMT (fairseq transformer)
|
||||
See: https://github.com/pytorch/fairseq/tree/master/examples/wmt19
|
||||
|
||||
- Moses preprocessing & tokenization for most supported languages
|
||||
- (optionally) lowercase & normalize all input text
|
||||
- argument ``special_tokens`` and function ``set_special_tokens`` can be used to add additional symbols \
  (ex: "__classify__") to a vocabulary
|
||||
- `langs` defines a pair of languages
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
||||
should refer to the superclass for more information regarding methods.
|
||||
|
||||
Args:
|
||||
langs (:obj:`List[str]`):
|
||||
a list of two languages to translate from and to, e.g. ``["en", "ru"]``.
|
||||
src_vocab_file (:obj:`string`):
|
||||
Source language vocabulary file.
|
||||
tgt_vocab_file (:obj:`string`):
|
||||
Target language vocabulary file.
|
||||
merges_file (:obj:`string`):
|
||||
Merges file.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to lowercase the input when tokenizing.
|
||||
unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
bos_token (:obj:`string`, `optional`, defaults to "<s>"):
|
||||
The beginning of sequence token that was used during pre-training. Can be used as a sequence classifier token.
|
||||
|
||||
.. note::
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the beginning
|
||||
of sequence. The token used is the :obj:`cls_token`.
|
||||
sep_token (:obj:`string`, `optional`, defaults to "</s>"):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
||||
for sequence classification or for a text and a question for question answering.
|
||||
It is also used as the last token of a sequence built with special tokens.
|
||||
pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
langs=None,
|
||||
src_vocab_file=None,
|
||||
tgt_vocab_file=None,
|
||||
merges_file=None,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
sep_token="</s>",
|
||||
pad_token="<pad>",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.src_vocab_file = src_vocab_file
|
||||
self.tgt_vocab_file = tgt_vocab_file
|
||||
self.merges_file = merges_file
|
||||
|
||||
# cache of sm.MosesPunctNormalizer instance
|
||||
self.cache_moses_punct_normalizer = dict()
|
||||
# cache of sm.MosesTokenizer instance
|
||||
self.cache_moses_tokenizer = dict()
|
||||
self.cache_moses_detokenizer = dict()
|
||||
|
||||
if langs and len(langs) == 2:
|
||||
self.src_lang, self.tgt_lang = langs
|
||||
else:
|
||||
raise ValueError(
|
||||
f"arg `langs` needs to be a list of 2 langs, e.g. ['en', 'ru'], but got {langs}. "
|
||||
"Usually that means that tokenizer can't find a mapping for the given model path "
|
||||
"in PRETRAINED_VOCAB_FILES_MAP, and other maps of this tokenizer."
|
||||
)
|
||||
|
||||
with open(src_vocab_file, encoding="utf-8") as src_vocab_handle:
|
||||
self.encoder = json.load(src_vocab_handle)
|
||||
with open(tgt_vocab_file, encoding="utf-8") as tgt_vocab_handle:
|
||||
tgt_vocab = json.load(tgt_vocab_handle)
|
||||
self.decoder = {v: k for k, v in tgt_vocab.items()}
|
||||
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||
merges = merges_handle.read().split("\n")[:-1]
|
||||
merges = [tuple(merge.split()[:2]) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
|
||||
# hack override
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self.get_src_vocab()
|
||||
|
||||
# hack override
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.src_vocab_size
|
||||
|
||||
def moses_punct_norm(self, text, lang):
|
||||
if lang not in self.cache_moses_punct_normalizer:
|
||||
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
|
||||
self.cache_moses_punct_normalizer[lang] = punct_normalizer
|
||||
return self.cache_moses_punct_normalizer[lang].normalize(text)
|
||||
|
||||
def moses_tokenize(self, text, lang):
|
||||
if lang not in self.cache_moses_tokenizer:
|
||||
moses_tokenizer = sm.MosesTokenizer(lang=lang)
|
||||
self.cache_moses_tokenizer[lang] = moses_tokenizer
|
||||
return self.cache_moses_tokenizer[lang].tokenize(
|
||||
text, aggressive_dash_splits=True, return_str=False, escape=True
|
||||
)
|
||||
|
||||
def moses_detokenize(self, tokens, lang):
    if lang not in self.cache_moses_detokenizer:
        moses_detokenizer = sm.MosesDetokenizer(lang=lang)
        self.cache_moses_detokenizer[lang] = moses_detokenizer
    return self.cache_moses_detokenizer[lang].detokenize(tokens)
|
||||
|
||||
def moses_pipeline(self, text, lang):
|
||||
text = replace_unicode_punct(text)
|
||||
text = self.moses_punct_norm(text, lang)
|
||||
text = remove_non_printing_char(text)
|
||||
return text
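    # Illustrative only (not executed): the pipeline mainly normalizes punctuation ahead of tokenization,
    # e.g. moses_pipeline('Hello, “world”!', lang='en') comes back as something like 'Hello, "world"!'
    # (the exact output depends on the installed sacremoses version).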
|
||||
|
||||
@property
|
||||
def src_vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
@property
|
||||
def tgt_vocab_size(self):
|
||||
return len(self.decoder)
|
||||
|
||||
def get_src_vocab(self):
|
||||
return dict(self.encoder, **self.added_tokens_encoder)
|
||||
|
||||
def get_tgt_vocab(self):
|
||||
return dict(self.decoder, **self.added_tokens_decoder)
|
||||
|
||||
def bpe(self, token):
|
||||
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token + "</w>"
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
if word == "\n </w>":
|
||||
word = "\n</w>"
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def _tokenize(self, text, lang="en", bypass_tokenizer=False):
|
||||
"""
|
||||
Tokenize a string given language code using Moses.
|
||||
|
||||
Details of tokenization:
|
||||
- [sacremoses](https://github.com/alvations/sacremoses): port of Moses
|
||||
- Install with `pip install sacremoses`
|
||||
|
||||
Args:
|
||||
- lang: ISO language code (default = 'en') (string). Languages should belong to the model's supported languages. However, we don't enforce it.
|
||||
- bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE.
|
||||
|
||||
Returns:
|
||||
List of tokens.
|
||||
"""
|
||||
# ignore `lang`, which currently isn't explicitly passed in tokenization_utils.py and always results in lang=en
|
||||
# if lang != self.src_lang:
|
||||
# raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
|
||||
lang = self.src_lang
|
||||
|
||||
if bypass_tokenizer:
|
||||
text = text.split()
|
||||
else:
|
||||
text = self.moses_pipeline(text, lang=lang)
|
||||
text = self.moses_tokenize(text, lang=lang)
|
||||
|
||||
split_tokens = []
|
||||
for token in text:
|
||||
if token:
|
||||
split_tokens.extend([t for t in self.bpe(token).split(" ")])
|
||||
|
||||
return split_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
|
||||
# remove BPE
|
||||
tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
|
||||
tokens = "".join(tokens).split()
|
||||
# detokenize
|
||||
text = self.moses_detokenize(tokens, self.tgt_lang)
|
||||
return text
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens.
|
||||
A FAIRSEQ_TRANSFORMER sequence has the following format:
|
||||
|
||||
- single sequence: ``<s> X </s>``
|
||||
- pair of sequences: ``<s> A </s> B </s>``
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
|
||||
# no bos used in fairseq
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0 + sep
|
||||
return token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` methods.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Set to True if the token list is already formatted with special tokens for the model
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(
|
||||
map(
|
||||
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
|
||||
token_ids_0,
|
||||
)
|
||||
)
|
||||
# no bos used in fairseq
|
||||
if token_ids_1 is not None:
|
||||
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
return ([0] * len(token_ids_0)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
||||
A FAIRSEQ_TRANSFORMER sequence pair mask has the following format:
|
||||
|
||||
::
|
||||
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
|
||||
# no bos used in fairseq
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0 + sep) * [0]
|
||||
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare model inputs for translation. For best performance, translate one sentence at a time."""
|
||||
|
||||
if type(src_texts) is not list:
|
||||
raise ValueError("src_texts is expected to be a list")
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
|
||||
Returns:
|
||||
:obj:`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
|
||||
src_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["src_vocab_file"])
|
||||
tgt_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["tgt_vocab_file"])
|
||||
merges_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
|
||||
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
|
||||
tgt_vocab = {v: k for k, v in self.decoder.items()}
|
||||
f.write(json.dumps(tgt_vocab, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merges_file, "w", encoding="utf-8") as writer:
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
"Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merges_file)
|
||||
)
|
||||
index = token_index
|
||||
writer.write(" ".join(bpe_tokens) + "\n")
|
||||
index += 1
|
||||
|
||||
return src_vocab_file, tgt_vocab_file, merges_file
|
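A short usage sketch tying the tokenizer and model together (assuming the facebook/wmt19-ru-en checkpoint is available, as exercised by the integration tests below; the sample sentence comes from those tests):

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

mname = "facebook/wmt19-ru-en"
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTForConditionalGeneration.from_pretrained(mname)

batch = tokenizer.prepare_seq2seq_batch(["Машинное обучение - это здорово, не так ли?"])
generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], num_beams=5)
print(tokenizer.decode(generated[0], skip_special_tokens=True))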
tests/test_modeling_fsmt.py (new file, 501 lines)
@@ -0,0 +1,501 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Huggingface
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import timeout_decorator # noqa
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
|
||||
from transformers.modeling_fsmt import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
_prepare_fsmt_decoder_inputs,
|
||||
invert_mask,
|
||||
shift_tokens_right,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
):
|
||||
self.parent = parent
|
||||
self.src_vocab_size = 99
|
||||
self.tgt_vocab_size = 99
|
||||
self.langs = ["ru", "en"]
|
||||
self.batch_size = 13
|
||||
self.seq_length = 7
|
||||
self.is_training = False
|
||||
self.use_labels = False
|
||||
self.hidden_size = 16
|
||||
self.num_hidden_layers = 2
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 4
|
||||
self.hidden_act = "relu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 20
|
||||
self.bos_token_id = 0
|
||||
self.pad_token_id = 1
|
||||
self.eos_token_id = 2
|
||||
torch.manual_seed(0)
|
||||
|
||||
# hack needed for modeling_common tests - even though this model doesn't really have this attribute
|
||||
self.vocab_size = self.src_vocab_size
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.src_vocab_size).clamp(
|
||||
3,
|
||||
)
|
||||
input_ids[:, -1] = 2 # Eos Token
|
||||
|
||||
config = FSMTConfig(
|
||||
vocab_size=self.src_vocab_size, # hack needed for common tests
|
||||
src_vocab_size=self.src_vocab_size,
|
||||
tgt_vocab_size=self.tgt_vocab_size,
|
||||
langs=self.langs,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
encoder_ffn_dim=self.intermediate_size,
|
||||
decoder_ffn_dim=self.intermediate_size,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
def prepare_fsmt_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
# TODO(SS): fix the below in a separate PR
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = True # This requires inputs_dict['input_ids']
|
||||
test_missing_keys = False # because FSMTForConditionalGeneration and FSMTModel now have identical state_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
self.langs = ["en", "ru"]
|
||||
config = {
|
||||
"langs": self.langs,
|
||||
"src_vocab_size": 10,
|
||||
"tgt_vocab_size": 20,
|
||||
}
|
||||
# XXX: hack to appease all the other models requiring `vocab_size`
|
||||
config["vocab_size"] = 99 # no such thing in FSMT
|
||||
self.config_tester = ConfigTester(self, config_class=FSMTConfig, **config)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# XXX: override test_model_common_attributes / different Embedding type
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding))
|
||||
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.modules.sparse.Embedding))
|
||||
|
||||
def test_initialization_more(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = FSMTModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
# test init
|
||||
# self.assertTrue((model.encoder.embed_tokens.weight == model.shared.weight).all().item())
|
||||
|
||||
def _check_var(module):
|
||||
"""Check that we initialized various parameters from N(0, config.init_std)."""
|
||||
self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)
|
||||
|
||||
_check_var(model.encoder.embed_tokens)
|
||||
_check_var(model.encoder.layers[0].self_attn.k_proj)
|
||||
_check_var(model.encoder.layers[0].fc1)
|
||||
# XXX: different std for fairseq version of SinusoidalPositionalEmbedding
|
||||
# self.assertAlmostEqual(torch.std(model.encoder.embed_positions.weights).item(), config.init_std, 2)
|
||||
|
||||
def test_advanced_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
|
||||
config, inputs_dict["input_ids"]
|
||||
)
|
||||
model = FSMTModel(config).to(torch_device).eval()
|
||||
|
||||
decoder_features_with_created_mask = model(**inputs_dict)[0]
|
||||
decoder_features_with_passed_mask = model(
|
||||
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
|
||||
self.assertEqual(
|
||||
decoder_features.size(),
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, config.tgt_vocab_size),
|
||||
)
|
||||
if decoder_attn_mask.min().item() < -1e3: # some tokens were masked
|
||||
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
|
||||
|
||||
# Test different encoder attention masks
|
||||
decoder_features_with_long_encoder_mask = model(
|
||||
inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
|
||||
def test_save_load_strict(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Passing inputs_embeds not implemented for FSMT.")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("model weights aren't tied in FSMT.")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
||||
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
||||
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
||||
|
||||
# with torch.no_grad():
|
||||
# tiny(**inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTHeadTests(unittest.TestCase):
|
||||
src_vocab_size = 99
|
||||
tgt_vocab_size = 99
|
||||
langs = ["ru", "en"]
|
||||
|
||||
def _get_config(self):
|
||||
return FSMTConfig(
|
||||
src_vocab_size=self.src_vocab_size,
|
||||
tgt_vocab_size=self.tgt_vocab_size,
|
||||
langs=self.langs,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
def _get_config_and_data(self):
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[71, 82, 18, 33, 46, 91, 2],
|
||||
[68, 34, 26, 58, 30, 82, 2],
|
||||
[5, 97, 17, 39, 94, 40, 2],
|
||||
[76, 83, 94, 25, 70, 78, 2],
|
||||
[87, 59, 41, 35, 48, 66, 2],
|
||||
[55, 13, 16, 58, 5, 2, 1], # note padding
|
||||
[64, 27, 31, 51, 12, 75, 2],
|
||||
[52, 64, 86, 17, 83, 39, 2],
|
||||
[48, 61, 9, 24, 71, 82, 2],
|
||||
[26, 1, 60, 48, 22, 13, 2],
|
||||
[21, 5, 62, 28, 14, 76, 2],
|
||||
[45, 98, 37, 86, 59, 48, 2],
|
||||
[70, 70, 50, 9, 28, 0, 2],
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
config = self._get_config()
|
||||
return config, input_ids, batch_size
|
||||
|
||||
def test_generate_beam_search(self):
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
|
||||
config = self._get_config()
|
||||
lm_model = FSMTForConditionalGeneration(config).to(torch_device)
|
||||
lm_model.eval()
|
||||
|
||||
max_length = 5
|
||||
new_input_ids = lm_model.generate(
|
||||
input_ids.clone(),
|
||||
do_sample=True,
|
||||
num_return_sequences=1,
|
||||
num_beams=2,
|
||||
no_repeat_ngram_size=3,
|
||||
max_length=max_length,
|
||||
)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
|
||||
# TODO(SS): uneven length batches, empty inputs
|
||||
|
||||
def test_shift_tokens_right(self):
|
||||
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
|
||||
shifted = shift_tokens_right(input_ids, 1)
|
||||
n_pad_before = input_ids.eq(1).float().sum()
|
||||
n_pad_after = shifted.eq(1).float().sum()
|
||||
self.assertEqual(shifted.shape, input_ids.shape)
|
||||
self.assertEqual(n_pad_after, n_pad_before - 1)
|
||||
self.assertTrue(torch.eq(shifted[:, 0], 2).all())
|
||||
|
||||
def test_generate_fp16(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
model = FSMTForConditionalGeneration(config).eval().to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
model.half()
|
||||
model.generate(input_ids, attention_mask=attention_mask)
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
def test_dummy_inputs(self):
|
||||
config, *_ = self._get_config_and_data()
|
||||
model = FSMTForConditionalGeneration(config).eval().to(torch_device)
|
||||
model(**model.dummy_inputs)
|
||||
|
||||
def test_prepare_fsmt_decoder_inputs(self):
|
||||
config, *_ = self._get_config_and_data()
|
||||
input_ids = _long_tensor(([4, 4, 2]))
|
||||
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||
ignore = float("-inf")
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids
|
||||
)
|
||||
expected_causal_mask = torch.tensor(
|
||||
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # never attend to the final token, because it's a pad token
|
||||
).to(input_ids.device)
|
||||
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
|
||||
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if torch.allclose(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
msg = "{} != {}".format(a, b)
|
||||
if prefix:
|
||||
msg = prefix + ": " + msg
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
tokenizers_cache = {}
|
||||
models_cache = {}
|
||||
default_mname = "facebook/wmt19-en-ru"
|
||||
|
||||
@cached_property
|
||||
def default_tokenizer(self):
|
||||
return self.get_tokenizer(self.default_mname)
|
||||
|
||||
@cached_property
|
||||
def default_model(self):
|
||||
return self.get_model(self.default_mname)
|
||||
|
||||
def get_tokenizer(self, mname):
|
||||
if mname not in self.tokenizers_cache:
|
||||
self.tokenizers_cache[mname] = FSMTTokenizer.from_pretrained(mname)
|
||||
return self.tokenizers_cache[mname]
|
||||
|
||||
def get_model(self, mname):
|
||||
if mname not in self.models_cache:
|
||||
self.models_cache[mname] = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
self.models_cache[mname].half()
|
||||
return self.models_cache[mname]
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
tokenizer = self.default_tokenizer
|
||||
model = FSMTModel.from_pretrained(self.default_mname).to(torch_device)
|
||||
|
||||
src_text = "My friend computer will translate this for me"
|
||||
input_ids = tokenizer([src_text], return_tensors="pt")["input_ids"]
|
||||
input_ids = _long_tensor(input_ids)
|
||||
inputs_dict = prepare_fsmt_inputs_dict(model.config, input_ids)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
expected_shape = torch.Size((1, 10, model.config.tgt_vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# expected numbers were generated with the en-ru model, using just fairseq's model4.pt
|
||||
# may have to adjust if switched to a different checkpoint
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["en-ru"],
|
||||
["ru-en"],
|
||||
["en-de"],
|
||||
["de-en"],
|
||||
]
|
||||
)
|
||||
@slow
|
||||
def test_translation(self, pair):
|
||||
text = {
|
||||
"en": "Machine learning is great, isn't it?",
|
||||
"ru": "Машинное обучение - это здорово, не так ли?",
|
||||
"de": "Maschinelles Lernen ist großartig, oder?",
|
||||
}
|
||||
|
||||
src, tgt = pair.split("-")
|
||||
print(f"Testing {src} -> {tgt}")
|
||||
mname = f"facebook/wmt19-{pair}"
|
||||
|
||||
src_sentence = text[src]
|
||||
tgt_sentence = text[tgt]
|
||||
|
||||
tokenizer = self.get_tokenizer(mname)
|
||||
model = self.get_model(mname)
|
||||
|
||||
input_ids = tokenizer.encode(src_sentence, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
padding_idx = 1
|
||||
tolerance = 1e-4
|
||||
|
||||
def test_basic(self):
|
||||
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
|
||||
emb1 = SinusoidalPositionalEmbedding(embedding_dim=6, padding_idx=self.padding_idx, init_size=6).to(
|
||||
torch_device
|
||||
)
|
||||
emb = emb1(input_ids)
|
||||
desired_weights = torch.tensor(
|
||||
[
|
||||
[9.0930e-01, 1.9999e-02, 2.0000e-04, -4.1615e-01, 9.9980e-01, 1.0000e00],
|
||||
[1.4112e-01, 2.9995e-02, 3.0000e-04, -9.8999e-01, 9.9955e-01, 1.0000e00],
|
||||
]
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(emb[0], desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
|
||||
)
|
||||
|
||||
def test_odd_embed_dim(self):
|
||||
# odd embedding_dim is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=4, embedding_dim=5, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
|
||||
# odd num_embeddings is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=5, embedding_dim=4, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
|
||||
@unittest.skip("different from marian (needs more research)")
|
||||
def test_positional_emb_weights_against_marian(self):
|
||||
|
||||
desired_weights = torch.tensor(
|
||||
[
|
||||
[0, 0, 0, 0, 0],
|
||||
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
|
||||
[0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
|
||||
]
|
||||
)
|
||||
emb1 = SinusoidalPositionalEmbedding(init_size=512, embedding_dim=512, padding_idx=self.padding_idx).to(
|
||||
torch_device
|
||||
)
|
||||
weights = emb1.weights.data[:3, :5]
|
||||
# XXX: only the 1st and 3rd lines match - this is testing against
|
||||
# verbatim copy of SinusoidalPositionalEmbedding from fairseq
|
||||
self.assertTrue(
|
||||
torch.allclose(weights, desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{weights}\n",
|
||||
)
|
||||
|
||||
# test that forward pass is just a lookup, there is no ignore padding logic
|
||||
input_ids = torch.tensor(
|
||||
[[4, 10, self.padding_idx, self.padding_idx, self.padding_idx]], dtype=torch.long, device=torch_device
|
||||
)
|
||||
no_cache_pad_zero = emb1(input_ids)[0]
|
||||
# XXX: only the 1st line matches the 3rd
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.tensor(desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3)
|
||||
)
|
tests/test_tokenization_fsmt.py (new file, 147 lines)
@@ -0,0 +1,147 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.tokenization_fsmt import VOCAB_FILES_NAMES, FSMTTokenizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = FSMTTokenizer
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"w</w>",
|
||||
"r</w>",
|
||||
"t</w>",
|
||||
"lo",
|
||||
"low",
|
||||
"er</w>",
|
||||
"low</w>",
|
||||
"lowest</w>",
|
||||
"newer</w>",
|
||||
"wider</w>",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||
|
||||
self.langs = ["en", "ru"]
|
||||
config = {
|
||||
"langs": self.langs,
|
||||
"src_vocab_size": 10,
|
||||
"tgt_vocab_size": 20,
|
||||
}
|
||||
|
||||
self.src_vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["src_vocab_file"])
|
||||
self.tgt_vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["tgt_vocab_file"])
|
||||
config_file = os.path.join(self.tmpdirname, "tokenizer_config.json")
|
||||
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.src_vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
with open(self.tgt_vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
with open(config_file, "w") as fp:
|
||||
fp.write(json.dumps(config))
|
||||
|
||||
@cached_property
|
||||
def tokenizer_ru_en(self):
|
||||
return FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en")
|
||||
|
||||
@cached_property
|
||||
def tokenizer_en_ru(self):
|
||||
return FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
tokenizer = FSMTTokenizer(self.langs, self.src_vocab_file, self.tgt_vocab_file, self.merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
@slow
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_ru_en
|
||||
|
||||
text = tokenizer.encode("sequence builders", add_special_tokens=False)
|
||||
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
|
||||
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == text + [2]
|
||||
assert encoded_pair == text + [2] + text_2 + [2]
|
||||
|
||||
@slow
|
||||
def test_match_encode_decode(self):
|
||||
tokenizer_enc = self.tokenizer_en_ru
|
||||
tokenizer_dec = self.tokenizer_ru_en
|
||||
|
||||
targets = [
|
||||
[
|
||||
"Here's a little song I wrote. Don't worry, be happy.",
|
||||
[2470, 39, 11, 2349, 7222, 70, 5979, 7, 8450, 1050, 13160, 5, 26, 6445, 7, 2],
|
||||
],
|
||||
["This is it. No more. I'm done!", [132, 21, 37, 7, 1434, 86, 7, 70, 6476, 1305, 427, 2]],
|
||||
]
|
||||
|
||||
# if data needs to be recreated or added, run:
|
||||
# import torch
|
||||
# model = torch.hub.load("pytorch/fairseq", "transformer.wmt19.en-ru", checkpoint_file="model4.pt", tokenizer="moses", bpe="fastbpe")
|
||||
# for src_text, _ in targets: print(f"""[\n"{src_text}",\n {model.encode(src_text).tolist()}\n],""")
|
||||
|
||||
for src_text, tgt_input_ids in targets:
|
||||
input_ids = tokenizer_enc.encode(src_text, return_tensors="pt")[0].tolist()
|
||||
self.assertListEqual(input_ids, tgt_input_ids)
|
||||
|
||||
# and decode backward, using the reversed languages model
|
||||
decoded_text = tokenizer_dec.decode(input_ids, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_text, src_text)
|
||||
|
||||
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
||||
def test_np_encode_plus_sent_to_model(self):
|
||||
pass
|