mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Specify checkpoint in saved file for run_lm_finetuning.py
This commit is contained in:
parent
e18f786cd5
commit
d7929899da
@ -63,10 +63,10 @@ MODEL_CLASSES = {
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
def __init__(self, tokenizer, file_path='train', block_size=512):
|
||||
def __init__(self, tokenizer, args, file_path='train', block_size=512):
|
||||
assert os.path.isfile(file_path)
|
||||
directory, filename = os.path.split(file_path)
|
||||
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
|
||||
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
|
||||
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
@ -99,7 +99,7 @@ class TextDataset(Dataset):
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
||||
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
||||
return dataset
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user