remove old methods

This commit is contained in:
thomwolf 2018-11-04 15:34:00 +01:00
parent 965b2565a0
commit c6207d85b6

View File

@ -305,37 +305,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else:
tokens_b.pop()
def input_fn_builder(features, seq_length, train_batch_size):
# TODO: delete
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
all_input_ids = [f.input_ids for feature in features]
all_input_mask = [f.input_mask for feature in features]
all_segment_ids = [f.segment_ids for feature in features]
all_label_ids = [f.label_id for feature in features]
# for feature in features:
# all_input_ids.append(feature.input_ids)
# all_input_mask.append(feature.input_mask)
# all_segment_ids.append(feature.segment_ids)
# all_label_ids.append(feature.label_id)
input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long)
label_tensor = torch.tensor(all_label_ids, dtype=torch.Long)
train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
segment_tensor, label_tensor)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
return train_dataloader
def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs==labels)