Update PreTrainedTokenizerBase to check/handle batch length for text_pair parameter (#11486)

* Update tokenization_utils_base.py

* add assertion

* check batch len

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add error message

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Author: Hamel Husain, 2021-04-28 07:11:17 -07:00, committed by GitHub
parent 2d27900b5d
commit c0eb218a55

@@ -2279,6 +2279,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         )
 
         if is_batched:
+            if isinstance(text_pair, str):
+                raise TypeError(
+                    "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+                )
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+                )
             batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
             return self.batch_encode_plus(
                 batch_text_or_text_pairs=batch_text_or_text_pairs,
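
For reference, a minimal usage sketch (not part of the commit) of how the new checks surface through the tokenizer's `__call__`; it assumes a version of transformers containing this change is installed and that a checkpoint such as "bert-base-uncased" is available:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    texts = ["How old are you?", "What is your name?"]
    pairs = ["I'm 6 years old."]  # deliberately one entry short

    # Mismatched batch lengths now raise ValueError instead of being silently
    # truncated by the zip() call shown in the diff above.
    try:
        tokenizer(texts, pairs, padding=True)
    except ValueError as err:
        print(err)

    # A single string passed as `text_pair` alongside a batched `text` now raises TypeError.
    try:
        tokenizer(texts, "I'm 6 years old.")
    except TypeError as err:
        print(err)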