mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
typos in input_fn_builder
This commit is contained in:
parent
836faed985
commit
d3a8df6b9f
@ -434,7 +434,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
|
||||
### ATTENTION - I removed the `use_tpu` argument
|
||||
|
||||
|
||||
def input_fn_builder(features, seq_length, is_training, eval_drop_remainder):
|
||||
def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
||||
|
||||
all_input_ids = []
|
||||
@ -462,7 +462,7 @@ def input_fn_builder(features, seq_length, is_training, eval_drop_remainder):
|
||||
"label_ids": torch.IntTensor(all_label_ids, device=device)
|
||||
})
|
||||
|
||||
shuffle = True if training else False
|
||||
shuffle = True if is_training else False
|
||||
d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size,
|
||||
shuffle=shuffle,drop_last=drop_remainder)
|
||||
# Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
|
||||
|
Loading…
Reference in New Issue
Block a user