mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-26 07:49:01 +06:00
[Bart: example] drop columns that are exclusively pad_token_id… (#3400)
* Trim seq_len below 1024 if there are columns consisting entirely of pad_token_id.
* Centralize trim_batch so SummarizationDataset can use it too.
This commit is contained in:
parent
63f4d8cad0
commit
c10decf7a0
@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
|||||||
files = self._tokenizer.save(folder, name=file)
|
files = self._tokenizer.save(folder, name=file)
|
||||||
|
|
||||||
return tuple(files)
|
return tuple(files)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id.

    Args:
        input_ids: 2-D tensor of token ids, shape (batch, seq_len).
        pad_token_id: id whose all-pad columns should be dropped.
        attention_mask: optional tensor of the same shape; trimmed with the
            same column mask so it stays aligned with ``input_ids``.

    Returns:
        The trimmed ``input_ids`` tensor, or a ``(input_ids, attention_mask)``
        tuple when an attention mask was supplied.
    """
    # A column survives if at least one row holds a non-pad token.
    keep = (input_ids != pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep]
    return (input_ids[:, keep], attention_mask[:, keep])
|
||||||
|
Loading…
Reference in New Issue
Block a user