Mirror of https://github.com/huggingface/transformers.git
Merge pull request #515 from Rocketknight1/master
Fix --reduce_memory in finetune_on_pregenerated
This commit is contained in:
commit c36cca075a
@@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
                                   mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
             input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            segment_ids = np.memmap(filename=self.working_dir/'input_masks.memmap',
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
             lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
                                      shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
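For context: np.memmap with mode='w+' creates or overwrites the target file, so pointing segment_ids at 'input_masks.memmap' silently clobbered the mask data on disk. A minimal, self-contained sketch of that failure mode (the file name and shapes here are illustrative, not taken from the script):

# Sketch of the pre-fix bug: two memmaps opened with mode='w+' on the SAME
# file alias each other, and the second open wipes the first array's data.
# (dtype=bool is used here because np.bool is removed in modern NumPy.)
import os
import tempfile
import numpy as np

working_dir = tempfile.mkdtemp()
path = os.path.join(working_dir, 'input_masks.memmap')

input_masks = np.memmap(path, mode='w+', dtype=bool, shape=(2, 4))
input_masks[:] = True
input_masks.flush()

# The buggy line reused 'input_masks.memmap'; mode='w+' re-creates the file,
# zeroing everything input_masks had written.
segment_ids = np.memmap(path, mode='w+', dtype=bool, shape=(2, 4))
segment_ids.flush()

reread = np.memmap(path, mode='r', dtype=bool, shape=(2, 4))
print(reread.any())  # False: the input_masks contents were destroyed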
@@ -283,7 +283,7 @@ def main():
     model.train()
     for epoch in range(args.epochs):
         epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
-                                            num_data_epochs=num_data_epochs)
+                                            num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
         if args.local_rank == -1:
             train_sampler = RandomSampler(epoch_dataset)
         else:
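And for the second hunk: the --reduce_memory flag was parsed but never passed to the dataset constructor, so the memmap-backed path was never taken. A minimal sketch of the pattern being restored (ToyDataset and its fields are hypothetical stand-ins for PregeneratedDataset, not the script's actual code):

# Hypothetical miniature of the reduce_memory switch: back the arrays with
# on-disk memmaps when the flag is set, otherwise allocate in-RAM arrays.
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np

class ToyDataset:
    def __init__(self, num_samples, seq_len, reduce_memory=False):
        if reduce_memory:
            self._tmp = TemporaryDirectory()  # holds the backing files
            working_dir = Path(self._tmp.name)
            self.input_ids = np.memmap(filename=working_dir / 'input_ids.memmap',
                                       mode='w+', dtype=np.int32,
                                       shape=(num_samples, seq_len))
        else:
            self.input_ids = np.zeros((num_samples, seq_len), dtype=np.int32)

# Passing the flag through (the bug was dropping it) selects the memmap path:
ds = ToyDataset(num_samples=8, seq_len=128, reduce_memory=True)
print(type(ds.input_ids))  # <class 'numpy.memmap'>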