mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Specify checkpoint in saved file for run_lm_finetuning.py
This commit is contained in:
parent
e18f786cd5
commit
d7929899da
@ -63,10 +63,10 @@ MODEL_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
def __init__(self, tokenizer, file_path='train', block_size=512):
|
def __init__(self, tokenizer, args, file_path='train', block_size=512):
|
||||||
assert os.path.isfile(file_path)
|
assert os.path.isfile(file_path)
|
||||||
directory, filename = os.path.split(file_path)
|
directory, filename = os.path.split(file_path)
|
||||||
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
|
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
|
||||||
|
|
||||||
if os.path.exists(cached_features_file):
|
if os.path.exists(cached_features_file):
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
@ -99,7 +99,7 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||||
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user