[Bart: example] drop columns that are exclusively pad_token_id… (#3400)

* trim seq_len below 1024 if there are columns full of pad_token_id
* Centralize trim_batch so SummarizationDataset can use it too
This commit is contained in:
Sam Shleifer 2020-03-26 19:33:54 -04:00 committed by GitHub
parent 63f4d8cad0
commit c10decf7a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
files = self._tokenizer.save(folder, name=file)
return tuple(files)
def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Drop the columns of a padded batch that contain only ``pad_token_id``.

    A column is kept as soon as any row holds a non-pad token there, so the
    batch is shortened to the longest real sequence it contains.

    Args:
        input_ids: 2D tensor of token ids, shape (batch, seq_len).
        pad_token_id: id used for padding positions.
        attention_mask: optional 2D tensor with the same shape as
            ``input_ids``; trimmed with the identical column selection.

    Returns:
        The trimmed ``input_ids`` tensor, or a ``(input_ids, attention_mask)``
        tuple when a mask was supplied.
    """
    # Boolean per column: True where at least one row is a real token.
    non_pad_columns = (input_ids != pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, non_pad_columns]
    if attention_mask is None:
        return trimmed_ids
    return (trimmed_ids, attention_mask[:, non_pad_columns])