Mirror of https://github.com/huggingface/transformers.git
Align with run_squad
+ fix some errors
parent 45634f87f8
commit ebd45980a0
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation."""
+""" This is the exact same script as `examples/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation."""
 
 import argparse
 import glob
@@ -60,6 +60,7 @@ try:
 except ImportError:
     from tensorboardX import SummaryWriter
 
+
 logger = logging.getLogger(__name__)
 
 ALL_MODELS = sum(
@@ -114,11 +115,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
     scheduler = get_linear_schedule_with_warmup(
         optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
     )
-    if args.fp16:
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
+        os.path.join(args.model_name_or_path, "scheduler.pt")
+    ):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
+
+    if args.fp16:
         try:
             from apex import amp
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
 
         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
 
     # multi-gpu training (should be after apex fp16 initialization)
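Note on the hunk above: it restores optimizer and scheduler state when --model_name_or_path points at a previously saved checkpoint directory. A minimal sketch of that restore step in isolation, assuming standard PyTorch state_dict handling (the function name and ckpt_dir are illustrative, not part of the script):

# Sketch: restore optimizer/scheduler state that was saved with torch.save(obj.state_dict(), path).
# maybe_restore_training_state and ckpt_dir are illustrative names, not from the script.
import os

import torch


def maybe_restore_training_state(ckpt_dir, optimizer, scheduler):
    opt_path = os.path.join(ckpt_dir, "optimizer.pt")
    sched_path = os.path.join(ckpt_dir, "scheduler.pt")
    if os.path.isfile(opt_path) and os.path.isfile(sched_path):
        optimizer.load_state_dict(torch.load(opt_path))
        scheduler.load_state_dict(torch.load(sched_path))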
@@ -145,18 +156,47 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
     logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info(" Total optimization steps = %d", t_total)
 
-    global_step = 0
+    global_step = 1
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        try:
+            # set global_step to gobal_step of last saved checkpoint from model path
+            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
+            global_step = int(checkpoint_suffix)
+            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+            logger.info(" Continuing training from epoch %d", epochs_trained)
+            logger.info(" Continuing training from global step %d", global_step)
+            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+        except ValueError:
+            logger.info(" Starting fine-tuning.")
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
-    set_seed(args)  # Added here for reproductibility
+    train_iterator = trange(
+        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
+    )
+    # Added here for reproductibility
+    set_seed(args)
+
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             if teacher is not None:
                 teacher.eval()
             batch = tuple(t.to(args.device) for t in batch)
 
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
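Note on the hunk above: resuming derives the saved global step from a checkpoint directory name such as checkpoint-1500, then converts it into whole epochs plus leftover optimizer steps to skip. A standalone sketch of the same arithmetic, with an illustrative helper name:

# Sketch: derive the resume position from a "checkpoint-<global_step>" directory name.
# resume_position is an illustrative helper, not part of the script.
def resume_position(model_name_or_path, batches_per_epoch, grad_accum_steps):
    # "out/checkpoint-1500" -> "1500"; the second split drops anything after a trailing "/"
    checkpoint_suffix = model_name_or_path.split("-")[-1].split("/")[0]
    global_step = int(checkpoint_suffix)  # raises ValueError for a plain model name
    updates_per_epoch = batches_per_epoch // grad_accum_steps
    epochs_trained = global_step // updates_per_epoch
    steps_trained_in_current_epoch = global_step % updates_per_epoch
    return global_step, epochs_trained, steps_trained_in_current_epoch


# e.g. resume_position("out/checkpoint-1500", batches_per_epoch=1000, grad_accum_steps=1)
# -> (1500, 1, 500): skip one full epoch, then the first 500 batches of the next one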
@@ -167,6 +207,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                 inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
             if args.model_type in ["xlnet", "xlm"]:
                 inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
+                if args.version_2_with_negative:
+                    inputs.update({"is_impossible": batch[7]})
             outputs = model(**inputs)
             loss, start_logits_stu, end_logits_stu = outputs
 
@@ -219,11 +261,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                 model.zero_grad()
                 global_step += 1
 
+                # Log metrics
                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
-                    # Log metrics
-                    if (
-                        args.local_rank == -1 and args.evaluate_during_training
-                    ):  # Only evaluate when single GPU otherwise metrics may not average well
+                    # Only evaluate when single GPU otherwise metrics may not average well
+                    if args.local_rank == -1 and args.evaluate_during_training:
                         results = evaluate(args, model, tokenizer)
                         for key, value in results.items():
                             tb_writer.add_scalar("eval_{}".format(key), value, global_step)
@@ -240,9 +281,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                         model.module if hasattr(model, "module") else model
                     )  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, "training_args.bin"))
                     logger.info("Saving model checkpoint to %s", output_dir)
+
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
 
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
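Note on the hunk above: after this change every checkpoint-<step> directory holds the model weights, the tokenizer files, training_args.bin, optimizer.pt and scheduler.pt. A hedged sketch of reloading all of them from such a directory; BertForQuestionAnswering/BertTokenizer stand in for whatever classes --model_type selects, and the path is illustrative:

# Sketch: reload everything that a checkpoint-<step> directory now contains.
# The model/tokenizer classes and the path are illustrative placeholders.
import os

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

ckpt_dir = "output/checkpoint-1500"  # illustrative
model = BertForQuestionAnswering.from_pretrained(ckpt_dir)
tokenizer = BertTokenizer.from_pretrained(ckpt_dir)
training_args = torch.load(os.path.join(ckpt_dir, "training_args.bin"))
optimizer_state = torch.load(os.path.join(ckpt_dir, "optimizer.pt"))
scheduler_state = torch.load(os.path.join(ckpt_dir, "scheduler.pt"))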
@@ -263,18 +310,27 @@ def evaluate(args, model, tokenizer, prefix=""):
         os.makedirs(args.output_dir)
 
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
 
     # Note that DistributedSampler samples randomly
     eval_sampler = SequentialSampler(dataset)
     eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 
+    # multi-gpu evaluate
+    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
+        model = torch.nn.DataParallel(model)
+
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(dataset))
     logger.info(" Batch size = %d", args.eval_batch_size)
 
     all_results = []
+    start_time = timeit.default_timer()
+
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)
 
         with torch.no_grad():
             inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
             if args.model_type != "distilbert":
@@ -282,6 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
             example_indices = batch[3]
             if args.model_type in ["xlnet", "xlm"]:
                 inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
+
             outputs = model(**inputs)
 
             for i, example_index in enumerate(example_indices):
@@ -314,9 +371,13 @@ def evaluate(args, model, tokenizer, prefix=""):
 
             all_results.append(result)
 
+    evalTime = timeit.default_timer() - start_time
+    logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
+
     # Compute predictions
     output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
     output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
 
     if args.version_2_with_negative:
         output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
     else:
@@ -333,7 +394,6 @@ def evaluate(args, model, tokenizer, prefix=""):
             output_prediction_file,
             output_nbest_file,
             output_null_log_odds_file,
-            args.predict_file,
             model.config.start_n_top,
             model.config.end_n_top,
             args.version_2_with_negative,
@@ -364,7 +424,8 @@ def evaluate(args, model, tokenizer, prefix=""):
 
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank not in [-1, 0] and not evaluate:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+        torch.distributed.barrier()
 
     # Load data features from cache or dataset file
     input_file = args.predict_file if evaluate else args.train_file
@@ -395,9 +456,9 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         logger.info("Creating features from dataset file at %s", input_file)
         processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
         if evaluate:
-            examples = processor.get_dev_examples(None, filename=args.predict_file)
+            examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
         else:
-            examples = processor.get_train_examples(None, filename=args.train_file)
+            examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
 
         features, dataset = squad_convert_examples_to_features(
             examples=examples,
@@ -407,13 +468,16 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
             max_query_length=args.max_query_length,
             is_training=not evaluate,
             return_dataset="pt",
+            threads=args.threads,
         )
 
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
 
     if args.local_rank == 0 and not evaluate:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+        torch.distributed.barrier()
 
     if output_examples:
         return dataset, examples, features
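Note on the hunk above: the new threads value is forwarded to squad_convert_examples_to_features, which can spread example-to-feature conversion over several worker processes. A usage sketch assuming the surrounding call looks like the one in run_squad; the data path, tokenizer choice and numeric limits are illustrative:

# Sketch: multi-process SQuAD feature conversion, mirroring the call above.
# The data path, tokenizer choice and numeric limits are illustrative.
from transformers import BertTokenizer, SquadV1Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = SquadV1Processor().get_dev_examples("data", filename="dev-v1.1.json")

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
    threads=4,  # number of worker processes used for the conversion
)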
@@ -424,16 +488,6 @@ def main():
     parser = argparse.ArgumentParser()
 
     # Required parameters
-    parser.add_argument(
-        "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
-    )
-    parser.add_argument(
-        "--predict_file",
-        default=None,
-        type=str,
-        required=True,
-        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
-    )
     parser.add_argument(
         "--model_type",
         default=None,
@@ -480,6 +534,27 @@ def main():
     )
 
     # Other parameters
+    parser.add_argument(
+        "--data_dir",
+        default=None,
+        type=str,
+        help="The input data dir. Should contain the .json files for the task."
+        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+    )
+    parser.add_argument(
+        "--train_file",
+        default=None,
+        type=str,
+        help="The input training file. If a data dir is specified, will look for the file there"
+        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+    )
+    parser.add_argument(
+        "--predict_file",
+        default=None,
+        type=str,
+        help="The input evaluation file. If a data dir is specified, will look for the file there"
+        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+    )
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -548,7 +623,7 @@ def main():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
-    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
+    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
     parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument(
@@ -612,6 +687,8 @@ def main():
     )
     parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
     parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
+
+    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
     args = parser.parse_args()
 
     if (
@@ -666,7 +743,8 @@ def main():
 
     # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+        # Make sure only the first process in distributed training will download model & vocab
+        torch.distributed.barrier()
 
     args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
@@ -703,12 +781,24 @@ def main():
         teacher = None
 
     if args.local_rank == 0:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+        # Make sure only the first process in distributed training will download model & vocab
+        torch.distributed.barrier()
 
     model.to(args.device)
 
     logger.info("Training/evaluation parameters %s", args)
 
+    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
+    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
+    # remove the need for this code, but it is still valid.
+    if args.fp16:
+        try:
+            import apex
+
+            apex.amp.register_half_function(torch, "einsum")
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
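Note on the hunk above: apex exposes register_half_function to force a specific op (here torch.einsum) to run in fp16 under opt level O1, and the registration has to happen before amp.initialize. A minimal sketch of that ordering, with a throwaway model and optimizer as placeholders:

# Sketch: register torch.einsum for fp16 before amp.initialize, as the diff does.
# The tiny model/optimizer are placeholders; apex and a GPU are required to run this.
import torch
from apex import amp

amp.register_half_function(torch, "einsum")  # must be called before amp.initialize

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")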
@@ -734,15 +824,15 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
 
         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
-        tokenizer = tokenizer_class.from_pretrained(
-            args.output_dir, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None
-        )
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
 
     # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
+        if args.do_train:
+            logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -755,7 +845,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
+            model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
 
             # Evaluate