fixed bug in run_mlm_flax_stream.py (#17203)

* fixed bug run_mlm_flax_stream.py

Fixed bug caused by an update to tokenizer keys introduced in recent transformers versions (between `4.6.2` and `4.18.0`) where additional keys were introduced to the tokenizer output.

* Update run_mlm_flax_stream.py

* adding missing paranthesis

* formatted to black

* remove cols from dataset instead

* reformat to black

* moved rem. columns to map

* formatted to black

Co-authored-by: KennethEnevoldsen <kennethcenevolsen@gmail.com>
This commit is contained in:
Kenneth Enevoldsen 2022-05-16 13:40:27 +02:00 committed by GitHub
parent 71abd3ade1
commit 71d18d0831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -288,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
tokenized_samples = next(train_iterator)
i += len(tokenized_samples["input_ids"])
# concatenate tokenized samples to list
samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
# concatenate tokenized samples to list (excluding "id" and "text")
samples = {
k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
}
# Concatenated tokens are split to lists of length `max_seq_length`.
# Note that remainedr of % max_seq_length are thrown away.
@ -407,10 +409,7 @@ if __name__ == "__main__":
def tokenize_function(examples):
return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
tokenized_datasets = dataset.map(
tokenize_function,
batched=True,
)
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
shuffle_seed = training_args.seed
tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)