mirror of https://github.com/huggingface/transformers.git
input_fn_builder WIP
This commit is contained in:
parent f8e347b557
commit 555b7d66c9
@@ -23,6 +23,7 @@ import os
 # import modeling_pytorch
 # import optimization
 import tokenization_pytorch
+import torch
 
 import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -382,3 +383,63 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
       tokens_a.pop()
     else:
       tokens_b.pop()
+
+
+def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
+                 labels, num_labels, use_one_hot_embeddings):
+  raise NotImplementedError()
+
+
+def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
+                     num_train_steps, num_warmup_steps,
+                     use_one_hot_embeddings):
+  raise NotImplementedError()
+  ### ATTENTION - I removed the `use_tpu` argument
+
+
+def input_fn_builder(features, seq_length, is_training, drop_remainder):
+  """Creates an `input_fn` closure to be passed to TPUEstimator."""  ### ATTENTION - To rewrite ###
+
+  all_input_ids = []
+  all_input_mask = []
+  all_segment_ids = []
+  all_label_ids = []
+
+  for feature in features:
+    all_input_ids.append(feature.input_ids)
+    all_input_mask.append(feature.input_mask)
+    all_segment_ids.append(feature.segment_ids)
+    all_label_ids.append(feature.label_id)
+
+  def input_fn(params):
+    """The actual input function."""
+    batch_size = params["batch_size"]
+
+    num_examples = len(features)
+
+    # This is for demo purposes and does NOT scale to large data sets. We do
+    # not use Dataset.from_generator() because that uses tf.py_func which is
+    # not TPU compatible. The right way to load data is with TFRecordReader.
+    # torch.tensor() infers the [num_examples, seq_length] shape from the
+    # nested lists; torch.Tensor() accepts neither a size= nor a dtype=
+    # keyword, and requires_grad is already False by default.
+    d = tf.data.Dataset.from_tensor_slices({
+        "input_ids":
+            torch.tensor(all_input_ids, dtype=torch.int32),
+        "input_mask":
+            torch.tensor(all_input_mask, dtype=torch.int32),
+        "segment_ids":
+            torch.tensor(all_segment_ids, dtype=torch.int32),
+        "label_ids":
+            torch.tensor(all_label_ids, dtype=torch.int32)
+    })
+
+    if is_training:
+      d = d.repeat()
+      d = d.shuffle(buffer_size=100)
+
+    d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+    return d
+
+  return input_fn
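Both create_model and model_fn_builder above are still stubs that raise NotImplementedError. For orientation only, here is a minimal sketch of the direction create_model could take once the port lands, assuming modeling_pytorch eventually exposes a BertModel whose forward returns a pooled [CLS] output; BertClassifier and every signature below are hypothetical, not part of this commit.

import torch.nn as nn
import torch.nn.functional as F

class BertClassifier(nn.Module):
  """Hypothetical classification head for create_model; the BertModel
  interface assumed here (forward returning a pooled output) is an
  assumption, not part of this commit."""

  def __init__(self, bert_model, hidden_size, num_labels, dropout_prob=0.1):
    super(BertClassifier, self).__init__()
    self.bert = bert_model                      # assumed BertModel instance
    self.dropout = nn.Dropout(dropout_prob)
    self.classifier = nn.Linear(hidden_size, num_labels)

  def forward(self, input_ids, segment_ids, input_mask, labels=None):
    # Assumed to return (all_encoder_layers, pooled_output).
    _, pooled_output = self.bert(input_ids, segment_ids, input_mask)
    logits = self.classifier(self.dropout(pooled_output))
    if labels is not None:
      # Training path: return the classification loss alongside the logits.
      return F.cross_entropy(logits, labels), logits
    return logits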
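The input_fn_builder body above still routes torch tensors through tf.data.Dataset, and its docstring is explicitly flagged "### ATTENTION - To rewrite ###". A natural pure-PyTorch rewrite drops tf.data in favor of torch.utils.data. A minimal sketch, assuming each feature carries input_ids, input_mask, segment_ids, and label_id exactly as above (build_dataloader is a hypothetical name, not in this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_dataloader(features, batch_size, is_training, drop_remainder):
  """Hypothetical torch.utils.data replacement for input_fn_builder."""
  dataset = TensorDataset(
      torch.tensor([f.input_ids for f in features], dtype=torch.long),
      torch.tensor([f.input_mask for f in features], dtype=torch.long),
      torch.tensor([f.segment_ids for f in features], dtype=torch.long),
      torch.tensor([f.label_id for f in features], dtype=torch.long))
  # shuffle=is_training replaces d.shuffle(buffer_size=100) with a full
  # in-memory shuffle; drop_last mirrors drop_remainder.
  return DataLoader(dataset, batch_size=batch_size,
                    shuffle=is_training, drop_last=drop_remainder)

Each iteration over the loader yields an (input_ids, input_mask, segment_ids, label_ids) batch, matching the dict keys in the tf.data version one-to-one; seq_length is implicit in the feature lists, and epoch repetition, handled by d.repeat() above, becomes the training loop's responsibility.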