Merge pull request #515 from Rocketknight1/master

Fix --reduce_memory in finetune_on_pregenerated
This commit is contained in:
Thomas Wolf 2019-04-23 10:30:23 +02:00 committed by GitHub
commit c36cca075a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
segment_ids = np.memmap(filename=self.working_dir/'input_masks.memmap',
segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
@ -283,7 +283,7 @@ def main():
model.train()
for epoch in range(args.epochs):
epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
num_data_epochs=num_data_epochs)
num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
if args.local_rank == -1:
train_sampler = RandomSampler(epoch_dataset)
else: