mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Multi-Gpu loss - Cleaning
This commit is contained in:
parent 5de1517d6b
commit 3c24e4bef1
@@ -27,6 +27,7 @@ import math
import os
from tqdm import tqdm, trange
import random
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
@@ -718,23 +719,6 @@ def main():
parser.add_argument("--max_answer_length", default=30, type=int,
                    help="The maximum length of an answer that can be generated. This is needed because the start "
                         "and end predictions are not conditioned on one another.")

### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
# parser.add_argument("--use_tpu", default=False, action='store_true', help="Whether to use TPU or GPU/CPU.")
# parser.add_argument("--tpu_name", default=None, type=str,
#                     help="The Cloud TPU to use for training. This should be either the name used when creating the "
#                          "Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
# parser.add_argument("--tpu_zone", default=None, type=str,
#                     help="[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt "
#                          "to automatically detect the GCE project from metadata.")
# parser.add_argument("--gcp_project", default=None, type=str,
#                     help="[Optional] Project name for the Cloud TPU-enabled project. If not specified, we will attempt "
#                          "to automatically detect the GCE project from metadata.")
# parser.add_argument("--master", default=None, type=str, help="[Optional] TensorFlow master URL.")
# parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `use_tpu` is True. "
#                                                                  "Total number of TPU cores to use.")
### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###

parser.add_argument("--verbose_logging", default=False, action='store_true',
                    help="If true, all of the warnings related to data processing will be printed. "
                         "A number of warnings are expected for a normal SQuAD evaluation.")
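Note: the flags above follow the plain argparse pattern used throughout the script. A minimal, self-contained sketch of how such flags parse (the parser below is illustrative, not the script's full CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_answer_length", default=30, type=int)
parser.add_argument("--verbose_logging", default=False, action='store_true')

# e.g. run_squad.py --max_answer_length 50 --verbose_logging
args = parser.parse_args(["--max_answer_length", "50", "--verbose_logging"])
assert args.max_answer_length == 50 and args.verbose_logging is True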
@@ -836,16 +820,12 @@ def main():
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)

logger.info("HHHHH Loading data")
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
#all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

logger.info("HHHHH Creating dataset")
#train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
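Note: TensorDataset zips same-length tensors row-wise and DataLoader batches them, which is all the training pipeline above relies on. A minimal sketch with stand-in shapes (the sizes are illustrative):

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# stand-ins for the real feature tensors: 8 examples, sequence length 16
input_ids = torch.randint(0, 1000, (8, 16))
input_mask = torch.ones(8, 16, dtype=torch.long)
segment_ids = torch.zeros(8, 16, dtype=torch.long)
start_positions = torch.randint(0, 16, (8,))
end_positions = torch.randint(0, 16, (8,))

train_data = TensorDataset(input_ids, input_mask, segment_ids, start_positions, end_positions)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=4)
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
    pass  # each element is a batch of 4 rows, drawn in random order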
@@ -869,15 +849,11 @@ def main():
start_positions = start_positions.view(-1, 1)
end_positions = end_positions.view(-1, 1)

logger.info("HHHHH Forward")
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
model.zero_grad()
logger.info("HHHHH Backward")
loss.backward()
logger.info("HHHHH Loading data")
loss.mean().backward()
optimizer.step()
global_step += 1
logger.info("Done %s steps", global_step)

if args.do_predict:
    eval_examples = read_squad_examples(
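Note: the switch from loss.backward() to loss.mean().backward() is the multi-GPU fix this commit is named for. Under torch.nn.DataParallel, a loss computed inside the model's forward comes back as one value per GPU, and backward() on a non-scalar tensor raises unless a gradient argument is supplied. A minimal sketch of the failure mode (the values are illustrative):

import torch

# with nn.DataParallel on 2 GPUs, "loss" arrives as a vector of per-replica losses
loss = torch.tensor([0.71, 0.69], requires_grad=True)

# loss.backward() here would raise: grad can be implicitly created only for scalar outputs
loss.mean().backward()  # reduce to a scalar first, then backpropagate

On a single device the loss is already a scalar, and mean() of a scalar is a no-op, so the call is safe in both cases.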
@@ -898,10 +874,8 @@ def main():
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
#all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

#eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1:
    eval_sampler = SequentialSampler(eval_data)
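Note: the torch.arange tensor packed into eval_data acts as a row index, so each batch coming out of the DataLoader can be traced back to its original eval_features entry. A minimal sketch (the feature list is a stand-in):

import torch
from torch.utils.data import TensorDataset, DataLoader

features = ["feat_a", "feat_b", "feat_c", "feat_d"]  # stand-in for eval_features
inputs = torch.randn(4, 8)
example_index = torch.arange(inputs.size(0), dtype=torch.long)

loader = DataLoader(TensorDataset(inputs, example_index), batch_size=2)
for batch_inputs, batch_index in loader:
    originals = [features[i.item()] for i in batch_index]  # map rows back to their features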
@@ -912,7 +886,6 @@ def main():
model.eval()
all_results = []
logger.info("Start evaluating")
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
    if len(all_results) % 1000 == 0:
        logger.info("Processing example: %d" % (len(all_results)))
@@ -924,9 +897,7 @@ def main():
start_logits, end_logits = model(input_ids, segment_ids, input_mask)

unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
#start_logits = [x.item() for x in start_logits]
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
#end_logits = [x.item() for x in end_logits]
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
for idx, i in enumerate(unique_id):
    s = [float(x) for x in start_logits[idx]]
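Note: the .view(-1).detach().cpu().numpy() chain detaches each row of logits from the autograd graph, moves it to host memory, and flattens it so the values can become plain floats for JSON output. A short sketch:

import torch

start_logits = torch.randn(2, 6, requires_grad=True)  # stand-in for one model output batch
rows = [x.view(-1).detach().cpu().numpy() for x in start_logits]
s = [float(v) for v in rows[0]]  # plain Python floats, ready to serialize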
@@ -938,11 +909,6 @@ def main():
        end_logits=e
    )
)
# all_results.append(
#     RawResult(
#         unique_id=unique_id,
#         start_logits=start_logits,
#         end_logits=end_logits))

output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
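Note: RawResult here is the script's small named tuple keyed by unique_id, with fields matching the keyword arguments in the hunk above. A hedged sketch of how per-example results accumulate (the values are illustrative):

import collections

RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])

all_results = []
all_results.append(RawResult(unique_id=1000000001,
                             start_logits=[0.1, 0.9],
                             end_logits=[0.8, 0.2]))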