mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Refactor prepare_seq2seq_batch
(#9524)
* Add target contextmanager and rework prepare_seq2seq_batch * Fix tests, treat BART and Barthez * Add last tokenizers * Fix test * Set src token before calling the superclass * Remove special behavior for T5 * Remove needless imports * Remove needless asserts
This commit is contained in:
parent
e6ecef711e
commit
063d8d27f4
@ -13,10 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
@ -54,45 +50,3 @@ class BartTokenizer(RobertaTokenizer):
|
||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
@ -13,10 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
from .tokenization_bart import BartTokenizer
|
||||
@ -49,43 +45,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
||||
"tokenizer_file": {m: tokenizer_url for m in _all_bart_models},
|
||||
}
|
||||
slow_tokenizer_class = BartTokenizer
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
@ -21,9 +21,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -264,45 +262,3 @@ class BarthezTokenizer(PreTrainedTokenizer):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
@ -19,8 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@ -228,45 +227,3 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
@ -23,9 +23,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import sacremoses as sm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -484,40 +482,6 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
||||
return len(token_ids_0 + sep) * [0]
|
||||
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if type(src_texts) is not list:
|
||||
raise ValueError("src_texts is expected to be a list")
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
@ -15,15 +15,14 @@
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
vocab_files_names = {
|
||||
@ -182,40 +181,15 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
self.current_spm = self.spm_source
|
||||
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.current_spm = self.spm_target
|
||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
yield
|
||||
self.current_spm = self.spm_source
|
||||
return model_inputs
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
|
@ -13,11 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...utils import logging
|
||||
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
@ -172,52 +171,28 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: Optional[str] = None,
|
||||
add_prefix_space: bool = False, # ignored
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
@ -13,13 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
from tokenizers import processors
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...utils import logging
|
||||
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||
|
||||
@ -171,51 +171,28 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
@ -18,9 +18,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -250,36 +248,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
@ -19,8 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@ -188,36 +187,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
@ -17,9 +17,7 @@ import collections
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
|
||||
|
||||
@ -288,43 +286,3 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
|
||||
return token_ids_0 + [self.sep_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
return token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
return model_inputs
|
||||
|
@ -16,8 +16,7 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_rag import RagConfig
|
||||
|
||||
@ -63,42 +62,18 @@ class RagTokenizer:
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
return self.generator.batch_decode(*args, **kwargs)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.question_encoder.model_max_length
|
||||
model_inputs: BatchEncoding = self.question_encoder(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = self.generator.model_max_length
|
||||
labels = self.generator(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
return super().prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
|
||||
)
|
||||
|
@ -23,9 +23,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -295,43 +293,3 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
return model_inputs
|
||||
|
@ -19,9 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@ -212,47 +210,3 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0 + eos) * [0]
|
||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.prefix_tokens = []
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
# set prefix_tokens for target text
|
||||
self.prefix_tokens = [self.pad_token_id]
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
self.prefix_tokens = []
|
||||
return model_inputs
|
||||
|
@ -738,80 +738,3 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
r"""
|
||||
|
||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`.
|
||||
|
||||
Args:
|
||||
src_texts: (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts: (:obj:`List[str]`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
|
||||
set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"If your model requires more than input_ids for a typical forward pass, you should implement this method. "
|
||||
"Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a "
|
||||
"reference implementation."
|
||||
)
|
||||
|
@ -23,6 +23,7 @@ import json
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
@ -1473,68 +1474,6 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
|
||||
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
|
||||
Arguments:
|
||||
src_texts (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts (:obj:`list`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
||||
to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
||||
class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
"""
|
||||
@ -3252,3 +3191,113 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
"indexing errors".format(len(ids), self.model_max_length)
|
||||
)
|
||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
yield
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
|
||||
Arguments:
|
||||
src_texts (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts (:obj:`list`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
||||
to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
# mBART-specific kwargs that should be ignored by other models.
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
with self.as_target_tokenizer():
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
return model_inputs
|
||||
|
@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||
def test_batch_generation_en_ROMANCE_multi(self):
|
||||
self._assert_generated_batch_equal_expected()
|
||||
|
||||
def test_tokenizer_handles_empty(self):
|
||||
normalized = self.tokenizer.normalize("")
|
||||
self.assertIsInstance(normalized, str)
|
||||
with self.assertRaises(ValueError):
|
||||
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
|
||||
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
device = 0 if torch_device == "cuda" else -1
|
||||
|
@ -83,6 +83,7 @@ class TokenizerTesterMixin:
|
||||
from_pretrained_kwargs = None
|
||||
from_pretrained_filter = None
|
||||
from_pretrained_vocab_key = "vocab_file"
|
||||
test_seq2seq = True
|
||||
|
||||
def setUp(self) -> None:
|
||||
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
|
||||
@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
if not self.test_seq2seq:
|
||||
return
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
|
||||
return
|
||||
# Longer text that will definitely require truncation.
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
|
@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = CTRLTokenizer
|
||||
test_rust_tokenizer = False
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
rust_tokenizer_class = GPT2TokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
from_pretrained_kwargs = {"add_prefix_space": True}
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -31,6 +31,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = OpenAIGPTTokenizer
|
||||
rust_tokenizer_class = OpenAIGPTTokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -33,6 +33,7 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = ReformerTokenizer
|
||||
rust_tokenizer_class = ReformerTokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -44,6 +44,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
test_rust_tokenizer = False
|
||||
space_between_special_tokens = True
|
||||
from_pretrained_filter = filter_non_english
|
||||
test_seq2seq = False
|
||||
|
||||
def get_table(
|
||||
self,
|
||||
|
@ -26,6 +26,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = TransfoXLTokenizer
|
||||
test_rust_tokenizer = False
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
Loading…
Reference in New Issue
Block a user