Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)

commit 2fb9a934b4
parent c8731b9583

    re-format
@@ -1,20 +1,22 @@
-import json
-import logging
-import random
 from argparse import ArgumentParser
-from collections import namedtuple
 from pathlib import Path
+import os
+import torch
+import logging
+import json
+import random
+import numpy as np
+from collections import namedtuple
 from tempfile import TemporaryDirectory
 
-import numpy as np
-import torch
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
 
+from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_transformers.modeling_bert import BertForPreTraining
-from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
 from pytorch_transformers.tokenization_bert import BertTokenizer
+from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
 
 InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
 
@@ -70,16 +72,16 @@ class PregeneratedDataset(Dataset):
         if reduce_memory:
             self.temp_dir = TemporaryDirectory()
             self.working_dir = Path(self.temp_dir.name)
-            input_ids = np.memmap(filename=self.working_dir / 'input_ids.memmap',
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
                                   mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
-            input_masks = np.memmap(filename=self.working_dir / 'input_masks.memmap',
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            segment_ids = np.memmap(filename=self.working_dir / 'segment_ids.memmap',
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            lm_label_ids = np.memmap(filename=self.working_dir / 'lm_label_ids.memmap',
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
                                      shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
             lm_label_ids[:] = -1
-            is_nexts = np.memmap(filename=self.working_dir / 'is_nexts.memmap',
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
                                  shape=(num_samples,), mode='w+', dtype=np.bool)
         else:
             input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
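Note: in the reduce_memory branch shown above, each feature array is backed by an np.memmap file on disk rather than an in-memory ndarray, so a large pregenerated epoch does not have to fit in RAM. A minimal standalone sketch of that idea follows (shapes and file names are illustrative; recent NumPy versions drop the np.bool alias used in this script in favor of np.bool_):

    import numpy as np
    from pathlib import Path
    from tempfile import TemporaryDirectory

    # Illustrative shapes only; the real script derives them from the pregenerated epoch files.
    num_samples, seq_len = 1000, 128

    temp_dir = TemporaryDirectory()
    working_dir = Path(temp_dir.name)

    # Disk-backed arrays: the data lives in files and pages are loaded lazily on access.
    input_ids = np.memmap(working_dir / 'input_ids.memmap',
                          mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
    input_masks = np.memmap(working_dir / 'input_masks.memmap',
                            mode='w+', dtype=np.bool_, shape=(num_samples, seq_len))

    # Writes go through to the backing file; reads behave like a normal ndarray.
    input_ids[0, :5] = [101, 2023, 2003, 1037, 102]
    print(input_ids[0, :5])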
@@ -123,8 +125,7 @@ def main():
     parser = ArgumentParser()
     parser.add_argument('--pregenerated_data', type=Path, required=True)
     parser.add_argument('--output_dir', type=Path, required=True)
-    parser.add_argument("--bert_model", type=str, required=True,
-                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
+    parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
                              "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
     parser.add_argument("--do_lower_case", action="store_true")
     parser.add_argument("--reduce_memory", action="store_true",
@@ -152,14 +153,14 @@ def main():
     parser.add_argument('--loss_scale',
                         type=float, default=0,
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
     parser.add_argument("--warmup_steps",
                         default=0,
                         type=int,
                         help="Linear warmup over warmup_steps.")
     parser.add_argument("--adam_epsilon",
                         default=1e-8,
                         type=float,
                         help="Epsilon for Adam optimizer.")
     parser.add_argument("--learning_rate",
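For context, the --learning_rate, --adam_epsilon and --warmup_steps flags above feed the optimizer setup later in the script. A rough sketch of how such flags are typically wired into AdamW and WarmupLinearSchedule from pytorch_transformers (the stand-in model, values and step count below are illustrative, not taken from this file):

    import torch
    from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

    model = torch.nn.Linear(10, 10)   # stand-in module; the script uses BertForPreTraining

    learning_rate, adam_epsilon, warmup_steps = 3e-5, 1e-8, 0   # illustrative values for the parsed args
    num_train_optimization_steps = 10000                        # illustrative total step count

    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                     t_total=num_train_optimization_steps)

    for _ in range(3):
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()              # apply the AdamW update
        scheduler.step()              # linear warmup, then linear decay of the learning rate
        optimizer.zero_grad()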
@@ -206,7 +207,7 @@ def main():
 
     if args.gradient_accumulation_steps < 1:
         raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
             args.gradient_accumulation_steps))
 
     args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
 
@@ -311,7 +312,7 @@ def main():
                 outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                 loss = outputs[0]
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
                     loss = loss / args.gradient_accumulation_steps
                 if args.fp16:
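The step above pairs with the earlier hunk that divides args.train_batch_size by args.gradient_accumulation_steps: each forward pass sees the smaller per-step batch, the loss is scaled down by the same factor, and gradients accumulate until one optimizer update covers the full effective batch. A small generic sketch of that pattern (illustrative model and sizes, fp16 branch omitted):

    import torch

    model = torch.nn.Linear(8, 1)                        # stand-in for BertForPreTraining
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_batch_size = 32
    gradient_accumulation_steps = 4
    per_step_batch = train_batch_size // gradient_accumulation_steps   # 8, as in the script

    for step in range(gradient_accumulation_steps):
        x = torch.randn(per_step_batch, 8)
        loss = model(x).pow(2).mean()
        loss = loss / gradient_accumulation_steps        # scale so accumulated gradients average correctly
        loss.backward()                                  # gradients add up across the small batches

    optimizer.step()                                     # one update for the effective batch of 32
    optimizer.zero_grad()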
@@ -336,8 +337,7 @@ def main():
     # Save a trained model
     if args.local_rank == -1 or torch.distributed.get_rank() == 0:
         logging.info("** ** * Saving fine-tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model,
-                                                'module') else model  # Take care of distributed/parallel training
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
         model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
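Usage note: after save_pretrained, --output_dir holds the model weights and config (WEIGHTS_NAME / CONFIG_NAME) plus the tokenizer vocabulary, so the directory can be reloaded directly. A brief sketch of the round trip, with an illustrative path standing in for args.output_dir:

    from pytorch_transformers.modeling_bert import BertForPreTraining
    from pytorch_transformers.tokenization_bert import BertTokenizer

    output_dir = "finetuned_lm"   # illustrative; the script uses args.output_dir

    model = BertForPreTraining.from_pretrained(output_dir)   # reads the config and weights written above
    tokenizer = BertTokenizer.from_pretrained(output_dir)    # reads the saved vocabulary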