mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
FIX errors in loading Dataset in run_squad_pytorch
This commit is contained in:
parent
72d69a4ef4
commit
833c3a7a25
@ -818,9 +818,12 @@ def main():
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
#all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
#train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
@ -829,13 +832,16 @@ def main():
|
||||
|
||||
model.train()
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
#for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
#label_ids = label_ids.to(device)
|
||||
start_positions = start_positions.to(device)
|
||||
end_positions = start_positions.to(device)
|
||||
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
|
Loading…
Reference in New Issue
Block a user