From c10decf7a0094694c18feab5a0cb15a41e6ee3df Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Thu, 26 Mar 2020 19:33:54 -0400
Subject: [PATCH] [Bart: example] drop columns that are exclusively pad_token_id… (#3400)

* trim seq_len below 1024 if there are columns full of pad_token_id

* Centralize trim_batch so SummarizationDataset can use it too
---
 src/transformers/tokenization_utils.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index cdb5e2839ac..1379912757a 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
         files = self._tokenizer.save(folder, name=file)
 
         return tuple(files)
+
+
+def trim_batch(
+    input_ids, pad_token_id, attention_mask=None,
+):
+    """Remove columns that are populated exclusively by pad_token_id"""
+    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
+    if attention_mask is None:
+        return input_ids[:, keep_column_mask]
+    else:
+        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
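
Note (not part of the patch): a minimal usage sketch of the new trim_batch helper, assuming PyTorch is installed and the function is imported from transformers.tokenization_utils as added above. The tensor values and pad_token_id = 0 below are illustrative only; in practice you would use tokenizer.pad_token_id.

    import torch

    from transformers.tokenization_utils import trim_batch

    pad_token_id = 0  # illustrative value, not taken from the patch

    # Batch of 2 sequences padded to length 6; the last two columns are all pad.
    input_ids = torch.tensor(
        [[5, 6, 7, 8, 0, 0],
         [9, 3, 0, 0, 0, 0]]
    )
    attention_mask = input_ids.ne(pad_token_id).long()

    # When attention_mask is passed, trim_batch returns a tuple of both trimmed tensors.
    trimmed_ids, trimmed_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
    print(trimmed_ids.shape)   # torch.Size([2, 4]) -- the two all-pad columns are dropped
    print(trimmed_mask.shape)  # torch.Size([2, 4])

Because keep_column_mask is computed with any(dim=0), a column is kept as long as any row has a real token there, so padding inside a single shorter sequence is preserved; only columns that are pad across the entire batch are removed.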