mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 (#22313)

* fix the prefix tokens
* update fast and test values
* add legacy behaviour
* update disclaimer, link issue/PR and behavioural changes
* Apply suggestions from code review
* styling
* make a quote
* quote this time

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent ad5e9b6c6a
commit 00b5887b94
@@ -12,8 +12,45 @@ specific language governing permissions and limitations under the License.
 # NLLB
 
-**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=bug&template=bug-report.yml) and assign
-@LysandreJik
+**DISCLAIMER:** The default behaviour for the tokenizer has recently been fixed (and thus changed)!
+
+The previous version adds `[self.eos_token_id, self.cur_lang_code]` at the end of the token sequence for both target and source tokenization. This is wrong as the NLLB paper mentions (page 48, 6.1.1. Model Architecture):
+
+*Note that we prefix the source sequence with the source language, as opposed to the target
+language as previously done in several works (Arivazhagan et al., 2019; Johnson et al.,
+2017). This is primarily because we prioritize optimizing zero-shot performance of our
+model on any pair of 200 languages at a minor cost to supervised performance.*
+
+Previous behaviour:
+
+```python
+>>> from transformers import NllbTokenizer
+
+>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+>>> tokenizer("How was your day?").input_ids
+[13374, 1398, 4260, 4039, 248130, 2, 256047]
+
+>>> # 2: '</s>'
+>>> # 256047: 'eng_Latn'
+```
+
+New behaviour:
+
+```python
+>>> from transformers import NllbTokenizer
+
+>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+>>> tokenizer("How was your day?").input_ids
+[256047, 13374, 1398, 4260, 4039, 248130, 2]
+```
+
+Enabling the old behaviour can be done as follows:
+
+```python
+>>> from transformers import NllbTokenizer
+
+>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)
+```
+
+For more details, feel free to check the linked [PR](https://github.com/huggingface/transformers/pull/22313) and [Issue](https://github.com/huggingface/transformers/issues/19943).
 
 ## Overview of NLLB
 
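A quick way to sanity-check the new default ordering outside of the doctest above is to assert on the positions of the special tokens directly. This is only an illustrative sketch (it reuses the `facebook/nllb-200-distilled-600M` checkpoint and the default `eng_Latn` source language from the docs), not part of the committed diff:

```python
from transformers import NllbTokenizer

tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# New default behaviour: source language code first, </s> last.
ids = tokenizer("How was your day?").input_ids
assert ids[0] == tokenizer.convert_tokens_to_ids("eng_Latn")  # 256047
assert ids[-1] == tokenizer.eos_token_id                      # 2
```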
@@ -140,12 +140,14 @@ class NllbTokenizer(PreTrainedTokenizer):
         tgt_lang=None,
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
         additional_special_tokens=None,
+        legacy_behaviour=False,
         **kwargs,
     ):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
 
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.legacy_behaviour = legacy_behaviour
 
         super().__init__(
             bos_token=bos_token,
@@ -160,13 +162,13 @@ class NllbTokenizer(PreTrainedTokenizer):
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
             sp_model_kwargs=self.sp_model_kwargs,
+            legacy_behaviour=legacy_behaviour,
             **kwargs,
         )
 
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
 
         # Original fairseq vocab and spm vocab must be "aligned":
         # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
         # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
@@ -388,13 +390,27 @@ class NllbTokenizer(PreTrainedTokenizer):
         return self.set_tgt_lang_special_tokens(self.tgt_lang)
 
     def set_src_lang_special_tokens(self, src_lang) -> None:
-        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
+        """Reset the special tokens to the source lang setting.
+        - In legacy mode: No prefix and suffix=[eos, src_lang_code].
+        - In default mode: Prefix=[src_lang_code], suffix = [eos]
+        """
         self.cur_lang_code = self.lang_code_to_id[src_lang]
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        if self.legacy_behaviour:
+            self.prefix_tokens = []
+            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        else:
+            self.prefix_tokens = [self.cur_lang_code]
+            self.suffix_tokens = [self.eos_token_id]
 
     def set_tgt_lang_special_tokens(self, lang: str) -> None:
-        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
+        """Reset the special tokens to the target lang setting.
+        - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
+        - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
+        """
         self.cur_lang_code = self.lang_code_to_id[lang]
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        if self.legacy_behaviour:
+            self.prefix_tokens = []
+            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        else:
+            self.prefix_tokens = [self.cur_lang_code]
+            self.suffix_tokens = [self.eos_token_id]
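The slow-tokenizer change above boils down to where `prefix_tokens` and `suffix_tokens` end up around the encoded text. A minimal, self-contained sketch of that assembly (the ids come from the doc example; `build_inputs` is an illustrative stand-in for the tokenizer's own `build_inputs_with_special_tokens`, not a copy of it):

```python
from typing import List

EOS_ID = 2            # '</s>'
ENG_LATN_ID = 256047  # 'eng_Latn'


def build_inputs(token_ids: List[int], legacy_behaviour: bool) -> List[int]:
    # Mirrors set_src_lang_special_tokens: legacy mode appends the language
    # code after </s>; the fixed default prepends it and keeps </s> last.
    if legacy_behaviour:
        prefix_tokens, suffix_tokens = [], [EOS_ID, ENG_LATN_ID]
    else:
        prefix_tokens, suffix_tokens = [ENG_LATN_ID], [EOS_ID]
    return prefix_tokens + token_ids + suffix_tokens


text_ids = [13374, 1398, 4260, 4039, 248130]  # "How was your day?" without specials
print(build_inputs(text_ids, legacy_behaviour=True))   # [..., 248130, 2, 256047] -> old behaviour
print(build_inputs(text_ids, legacy_behaviour=False))  # [256047, 13374, ..., 2]  -> new default
```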
@@ -151,11 +151,12 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         src_lang=None,
         tgt_lang=None,
         additional_special_tokens=None,
+        legacy_behaviour=False,
         **kwargs,
     ):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+        self.legacy_behaviour = legacy_behaviour
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
@@ -169,6 +170,7 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
+            legacy_behaviour=legacy_behaviour,
             **kwargs,
         )
 
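Because `from_pretrained` forwards extra keyword arguments to `__init__`, the new flag can be passed the same way for the slow and the fast tokenizer. A small parity check, again only a sketch that assumes the checkpoint used in the docs:

```python
from transformers import NllbTokenizer, NllbTokenizerFast

slow = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)
fast = NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)

text = "How was your day?"
# Both should reproduce the legacy ordering, ending in [..., 2, 256047].
assert slow(text).input_ids == fast(text).input_ids
```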
@@ -287,10 +289,18 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         return self.set_tgt_lang_special_tokens(self.tgt_lang)
 
     def set_src_lang_special_tokens(self, src_lang) -> None:
-        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
+        """Reset the special tokens to the source lang setting.
+        - In legacy mode: No prefix and suffix=[eos, src_lang_code].
+        - In default mode: Prefix=[src_lang_code], suffix = [eos]
+        """
         self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+        if self.legacy_behaviour:
+            self.prefix_tokens = []
+            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        else:
+            self.prefix_tokens = [self.cur_lang_code]
+            self.suffix_tokens = [self.eos_token_id]
 
         prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
         suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
@@ -302,10 +312,17 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         )
 
     def set_tgt_lang_special_tokens(self, lang: str) -> None:
-        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
+        """Reset the special tokens to the target lang setting.
+        - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
+        - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
+        """
         self.cur_lang_code = self.convert_tokens_to_ids(lang)
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        if self.legacy_behaviour:
+            self.prefix_tokens = []
+            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        else:
+            self.prefix_tokens = [self.cur_lang_code]
+            self.suffix_tokens = [self.eos_token_id]
 
         prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
         suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
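In the fast tokenizer, the id lists are additionally converted to token strings (the unchanged `prefix_tokens_str` / `suffix_tokens_str` lines above) and wired into a `tokenizers` post-processor. Roughly, the wiring under the new default looks like the sketch below; the token/id pairs are taken from the doc example and this is not a verbatim copy of the file:

```python
from tokenizers import processors

prefix_tokens_str = ["eng_Latn"]  # new default: language code as prefix
suffix_tokens_str = ["</s>"]      # </s> as suffix

post_processor = processors.TemplateProcessing(
    single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
    pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
    special_tokens=[("eng_Latn", 256047), ("</s>", 2)],
)
```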
@@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         " face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
     ]
     expected_src_tokens = [
+        256047,
         16297,
         134408,
         8165,
@@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         108,
         49486,
         2,
-        256047,
     ]
 
     @classmethod
@@ -355,8 +355,8 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         assert isinstance(src_text[0], str)
         desired_max_length = 10
         ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
-        self.assertEqual(ids[-2], 2)
-        self.assertEqual(ids[-1], EN_CODE)
+        self.assertEqual(ids[-1], 2)
+        self.assertEqual(ids[0], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)
 
     def test_mask_token(self):
@@ -389,10 +389,10 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         self.assertEqual((2, 15), batch.attention_mask.shape)
         result = batch.input_ids.tolist()[0]
         self.assertListEqual(self.expected_src_tokens, result)
-        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
+        self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0])  # EOS
         # Test that special tokens are reset
-        self.assertEqual(self.tokenizer.prefix_tokens, [])
-        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
+        self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
+        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
 
     def test_seq2seq_max_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
@@ -419,9 +419,27 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
             nested_simplify(inputs),
             {
                 # A, test, EOS, en_XX
-                "input_ids": [[70, 7356, 2, 256047]],
+                "input_ids": [[256047, 70, 7356, 2]],
                 "attention_mask": [[1, 1, 1, 1]],
                 # ar_AR
                 "forced_bos_token_id": 256057,
             },
         )
+
+    @require_torch
+    def test_legacy_behaviour(self):
+        self.tokenizer.legacy_behaviour = True
+        inputs = self.tokenizer(
+            "UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
+        )
+        self.assertEqual(
+            inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
+        )
+
+        self.tokenizer.legacy_behaviour = False
+        inputs = self.tokenizer(
+            "UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
+        )
+        self.assertEqual(
+            inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
+        )
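The updated `forced_bos_token_id` expectation above is the generation-time counterpart of the tokenizer fix: the target language is still forced as the first generated token. For reference, a typical translation call looks roughly like this (a sketch following the NLLB documentation; it downloads the checkpoint and the exact output text is not guaranteed):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt")
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"],  # force French output
    max_length=40,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```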