diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index d072ae3f04f..534dbfe3ae1 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -434,7 +434,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, ### ATTENTION - I removed the `use_tpu` argument -def input_fn_builder(features, seq_length, is_training, eval_drop_remainder): +def input_fn_builder(features, seq_length, is_training, drop_remainder): """Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ### all_input_ids = [] @@ -462,7 +462,7 @@ def input_fn_builder(features, seq_length, is_training, eval_drop_remainder): "label_ids": torch.IntTensor(all_label_ids, device=device) }) - shuffle = True if training else False + shuffle = True if is_training else False d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size, shuffle=shuffle,drop_last=drop_remainder) # Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html