mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
fix(run_language_modeling): use arg overwrite_cache (#4407)
This commit is contained in:
parent
d39bf0ac2d
commit
d9ece8233d
@ -120,7 +120,9 @@ def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, eva
|
||||
if args.line_by_line:
|
||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
else:
|
||||
return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
return TextDataset(
|
||||
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@ -216,6 +218,7 @@ def main():
|
||||
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
|
||||
|
||||
# Get datasets
|
||||
|
||||
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
|
Loading…
Reference in New Issue
Block a user