mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Merge pull request #515 from Rocketknight1/master
Fix --reduce_memory in finetune_on_pregenerated
This commit is contained in:
commit
c36cca075a
@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
|
|||||||
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
|
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
|
||||||
input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
|
input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
||||||
segment_ids = np.memmap(filename=self.working_dir/'input_masks.memmap',
|
segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
||||||
lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
|
lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
|
||||||
@ -283,7 +283,7 @@ def main():
|
|||||||
model.train()
|
model.train()
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
|
epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
|
||||||
num_data_epochs=num_data_epochs)
|
num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
train_sampler = RandomSampler(epoch_dataset)
|
train_sampler = RandomSampler(epoch_dataset)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user