From 555b7d66c91b11292c3c46ba1408068bf7595d4e Mon Sep 17 00:00:00 2001
From: VictorSanh
Date: Thu, 1 Nov 2018 02:10:46 -0400
Subject: [PATCH] `input_fn_builder` WIP

---
 run_classifier_pytorch.py | 59 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 58 insertions(+), 1 deletion(-)

diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py
index 17d490b7a8f..b3e3612542c 100644
--- a/run_classifier_pytorch.py
+++ b/run_classifier_pytorch.py
@@ -23,6 +23,7 @@ import os
 # import modeling_pytorch
 # import optimization
 import tokenization_pytorch
+import torch
 
 import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -381,4 +382,60 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
     if len(tokens_a) > len(tokens_b):
       tokens_a.pop()
     else:
-      tokens_b.pop()
\ No newline at end of file
+      tokens_b.pop()
+
+
+def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
+                 labels, num_labels, use_one_hot_embeddings):
+  raise NotImplementedError()
+
+
+def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
+                     num_train_steps, num_warmup_steps,
+                     use_one_hot_embeddings):
+  raise NotImplementedError()
+  ### ATTENTION - I removed the `use_tpu` argument
+
+
+def input_fn_builder(features, seq_length, is_training, drop_remainder):
+  """Creates an `input_fn` closure to be passed to TPUEstimator."""  ### ATTENTION - To rewrite ###
+
+  all_input_ids = []
+  all_input_mask = []
+  all_segment_ids = []
+  all_label_ids = []
+
+  for feature in features:
+    all_input_ids.append(feature.input_ids)
+    all_input_mask.append(feature.input_mask)
+    all_segment_ids.append(feature.segment_ids)
+    all_label_ids.append(feature.label_id)
+
+  def input_fn(params):
+    """The actual input function."""
+    batch_size = params["batch_size"]
+
+    num_examples = len(features)
+
+    # This is for demo purposes and does NOT scale to large data sets. We do
+    # not use Dataset.from_generator() because that uses tf.py_func which is
+    # not TPU compatible. The right way to load data is with TFRecordReader.
+    # NOTE: the tf.data pipeline below still needs tensorflow in scope and is
+    # flagged above for a PyTorch-native rewrite. Each nested list already
+    # stacks to shape [num_examples, seq_length]; int64 (torch.long) is the
+    # index dtype PyTorch embedding and loss ops expect.
+    d = tf.data.Dataset.from_tensor_slices({
+        "input_ids": torch.tensor(all_input_ids, dtype=torch.long),
+        "input_mask": torch.tensor(all_input_mask, dtype=torch.long),
+        "segment_ids": torch.tensor(all_segment_ids, dtype=torch.long),
+        "label_ids": torch.tensor(all_label_ids, dtype=torch.long)
+    })
+
+    if is_training:
+      d = d.repeat()
+      d = d.shuffle(buffer_size=100)
+
+    d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+    return d
+
+  return input_fn
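For reference, a minimal sketch of what the flagged `input_fn_builder` could become once the `tf.data` pipeline is replaced with `torch.utils.data`. This is illustrative only, not part of the patch: the helper name `build_dataloader` and its signature are assumptions, and only standard `TensorDataset`/`DataLoader` APIs are used.

import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)


def build_dataloader(features, batch_size, is_training, drop_remainder):
  """Hypothetical torch.utils.data stand-in for input_fn_builder."""
  # Nested per-example lists stack into [num_examples, seq_length] tensors;
  # torch.long is the index dtype embedding and loss ops expect.
  all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
  all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
  all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
  all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

  dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                          all_label_ids)
  # RandomSampler reshuffles every epoch, playing the role of
  # d.repeat().shuffle(...); the DataLoader is simply re-iterated per epoch.
  sampler = RandomSampler(dataset) if is_training else SequentialSampler(dataset)
  # drop_last is the torch counterpart of tf.data's drop_remainder.
  return DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                    drop_last=drop_remainder)

Each iteration then yields one batch as a tuple, e.g. `for input_ids, input_mask, segment_ids, label_ids in build_dataloader(features, 32, True, True): ...`, which avoids both the TPUEstimator closure and the `params["batch_size"]` indirection.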