def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Remove columns of ``input_ids`` that contain only ``pad_token_id``.

    A column is kept if at least one row holds a non-pad token; the same
    column mask is applied to ``attention_mask`` when one is given, so the
    two tensors stay aligned.

    Args:
        input_ids: 2D tensor of token ids, shape ``(batch, seq_len)``.
        pad_token_id: id whose all-pad columns should be dropped.
        attention_mask: optional 2D tensor with the same shape as
            ``input_ids``; trimmed with the identical column mask.

    Returns:
        The trimmed ``input_ids`` alone, or a ``(input_ids, attention_mask)``
        tuple when ``attention_mask`` was provided.
    """
    # Boolean per-column mask: True where any row is a real (non-pad) token.
    cols_to_keep = (input_ids != pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, cols_to_keep]
    if attention_mask is None:
        return trimmed_ids
    return trimmed_ids, attention_mask[:, cols_to_keep]