diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py
index 6a633245027..5c3051f5009 100644
--- a/examples/lm_finetuning/finetune_on_pregenerated.py
+++ b/examples/lm_finetuning/finetune_on_pregenerated.py
@@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
                                   mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
             input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            segment_ids = np.memmap(filename=self.working_dir/'input_masks.memmap',
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
             lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
                                      shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
@@ -283,7 +283,7 @@ def main():
     model.train()
     for epoch in range(args.epochs):
         epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
-                                            num_data_epochs=num_data_epochs)
+                                            num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
         if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
         else:
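
For context on why these two one-line changes matter: before the first hunk, `segment_ids` was mapped onto the same file as `input_masks`, so the two arrays aliased one on-disk buffer and writes to one silently clobbered the other; before the second hunk, `args.reduce_memory` was never forwarded, so `PregeneratedDataset` fell back to its constructor default and the memmap path above was not exercised. Below is a minimal sketch of the aliasing problem (file and variable names here are illustrative, not part of the patch):

```python
# Sketch only, not part of the patch: two np.memmap arrays backed by the
# same file share one buffer, so writes through one show up in the other.
import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    shared = Path(tmp) / 'input_masks.memmap'   # single file, as in the pre-fix code
    # First array creates and zero-fills the file.
    input_masks = np.memmap(shared, mode='w+', dtype=np.int32, shape=(4,))
    # Second array maps the *same* file ('r+' here so the demo does not
    # re-truncate it; the pre-fix code used 'w+' for both).
    segment_ids = np.memmap(shared, mode='r+', dtype=np.int32, shape=(4,))

    segment_ids[:] = 7        # intended to touch only segment_ids...
    print(input_masks[:])     # ...but input_masks now reads [7 7 7 7]
```

With a distinct filename per array, as in the patched line, each buffer gets its own backing file and the arrays no longer interfere.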