mirror of https://github.com/huggingface/transformers.git
fix dataset shuffling for Distributed training (huggingface#3721) (#3766)
parent 7972a4019f
commit 5ebd898953
@@ -317,8 +317,12 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer
         epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
     )
     set_seed(args)  # Added here for reproducibility
-    for _ in train_iterator:
+    for epoch in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+
+        if args.local_rank != -1:
+            train_sampler.set_epoch(epoch)
+
         for step, batch in enumerate(epoch_iterator):
 
             # Skip past any already trained steps if resuming training
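For context, a minimal standalone sketch of why the added set_epoch() call matters (this example is not part of the commit; the dataset contents, num_replicas, and rank values are hypothetical): torch.utils.data.distributed.DistributedSampler seeds its shuffle from seed + epoch, so without a set_epoch() call every epoch replays the same permutation on every worker.

# Standalone sketch, assuming only that PyTorch is installed; not part of this diff.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
# Passing num_replicas/rank explicitly avoids needing an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

# Without set_epoch(), the shuffle is seeded identically on each pass:
print(list(sampler) == list(sampler))  # True: the same order every "epoch"

# set_epoch() changes the seed offset, producing a fresh permutation:
sampler.set_epoch(1)
print(list(sampler))  # a different order than epoch 0

The guard on args.local_rank != -1 mirrors how the sampler is chosen earlier in these example scripts: a RandomSampler in the single-process case (which reshuffles on its own each epoch) and a DistributedSampler otherwise, so set_epoch() is only needed in the distributed branch.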