typos in input_fn_builder

This commit is contained in:
VictorSanh 2018-11-01 14:17:55 -04:00
parent 836faed985
commit d3a8df6b9f

View File

@ -434,7 +434,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
### ATTENTION - I removed the `use_tpu` argument
def input_fn_builder(features, seq_length, is_training, eval_drop_remainder):
def input_fn_builder(features, seq_length, is_training, drop_remainder):
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
all_input_ids = []
@ -462,7 +462,7 @@ def input_fn_builder(features, seq_length, is_training, eval_drop_remainder):
"label_ids": torch.IntTensor(all_label_ids, device=device)
})
shuffle = True if training else False
shuffle = True if is_training else False
d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size,
shuffle=shuffle,drop_last=drop_remainder)
# Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html