mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 15:28:59 +06:00
[Bart: example] drop columns that are exclusively pad_token_id… (#3400)
* trim seq_len below 1024 if there are columns full of pad_token_id * Centralize trim_batch so SummarizationDataset can use it too
This commit is contained in:
parent
63f4d8cad0
commit
c10decf7a0
@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||
files = self._tokenizer.save(folder, name=file)
|
||||
|
||||
return tuple(files)
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids, pad_token_id, attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
Loading…
Reference in New Issue
Block a user