fix(run_language_modeling): use arg overwrite_cache (#4407)

This commit is contained in:
Boris Dayma 2020-05-18 10:37:35 -05:00 committed by GitHub
parent d39bf0ac2d
commit d9ece8233d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -120,7 +120,9 @@ def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, eva
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
return TextDataset(
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
)
def main():
@ -216,6 +218,7 @@ def main():
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
# Get datasets
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
data_collator = DataCollatorForLanguageModeling(