mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix merge
This commit is contained in:
parent
562f864038
commit
8a2be93b4e
@ -374,24 +374,6 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
all_is_impossible = torch.tensor([1.0 if f.is_impossible == True else 0.0 for f in features], dtype=torch.float)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask, all_is_impossible)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
Loading…
Reference in New Issue
Block a user