mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
remove old methods
This commit is contained in:
parent
965b2565a0
commit
c6207d85b6
@ -305,37 +305,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def input_fn_builder(features, seq_length, train_batch_size):
|
||||
# TODO: delete
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
||||
|
||||
all_input_ids = [f.input_ids for feature in features]
|
||||
all_input_mask = [f.input_mask for feature in features]
|
||||
all_segment_ids = [f.segment_ids for feature in features]
|
||||
all_label_ids = [f.label_id for feature in features]
|
||||
|
||||
# for feature in features:
|
||||
# all_input_ids.append(feature.input_ids)
|
||||
# all_input_mask.append(feature.input_mask)
|
||||
# all_segment_ids.append(feature.segment_ids)
|
||||
# all_label_ids.append(feature.label_id)
|
||||
|
||||
input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
|
||||
input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
|
||||
segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long)
|
||||
label_tensor = torch.tensor(all_label_ids, dtype=torch.Long)
|
||||
|
||||
train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
|
||||
segment_tensor, label_tensor)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
|
||||
|
||||
return train_dataloader
|
||||
|
||||
def accuracy(out, labels):
|
||||
outputs = np.argmax(out, axis=1)
|
||||
return np.sum(outputs==labels)
|
||||
|
Loading…
Reference in New Issue
Block a user