Mirror of https://github.com/huggingface/transformers.git
PEP8 and formatting cleanups
commit 7de5c6aa5e (parent 1798e98e5a)
@@ -9,7 +9,7 @@ from collections import namedtuple
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from tqdm import tqdm, trange
+from tqdm import tqdm
 
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
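
One of the cleanups above drops the unused trange import. For context, trange(n) is just tqdm shorthand for tqdm(range(n)), so removing it once nothing calls it changes no behavior. A minimal illustration, assuming tqdm is installed:

    from tqdm import tqdm, trange

    # The two loops are equivalent: trange(3) is shorthand for tqdm(range(3)).
    for _ in trange(3):
        pass
    for _ in tqdm(range(3)):
        pass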
@@ -149,7 +149,8 @@ def main():
                         help="random seed for initialization")
     args = parser.parse_args()
 
-    assert args.pregenerated_data.is_dir(), "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
+    assert args.pregenerated_data.is_dir(), \
+        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
 
     samples_per_epoch = []
     for i in range(args.epochs):
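
A note on why this hunk wraps the long assert with a backslash rather than the implicit parenthesized continuation PEP 8 usually prefers: assert is a statement, not a function, so putting a single pair of parentheses around the condition and message turns them into a two-element tuple, which is always truthy and silently disables the check. A minimal sketch:

    x = 0

    # Correct: the backslash keeps a condition plus a separate failure message.
    assert x == 0, \
        "x should be zero"

    # Bug: this asserts a non-empty tuple, which is always true, so it can
    # never fail (CPython emits a SyntaxWarning pointing this out).
    assert (x == 1,
            "x should be one")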
@@ -237,7 +238,8 @@ def main():
             from apex.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
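
The surrounding try/except is the usual optional-dependency pattern: import apex only when fp16 training is requested, and re-raise ImportError with install instructions rather than a bare traceback. A minimal sketch of the pattern; build_optimizer and the plain-Adam fallback are illustrative stand-ins (the real script falls back to BertAdam), not the script's actual API:

    import torch

    def build_optimizer(params, lr, use_fp16=False):
        # apex is only needed on the fp16 path, so import it lazily.
        if use_fp16:
            try:
                from apex.optimizers import FusedAdam
            except ImportError:
                # Re-raise with actionable install instructions.
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex "
                    "to use distributed and fp16 training.")
            return FusedAdam(params, lr=lr)
        return torch.optim.Adam(params, lr=lr)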
@@ -293,7 +295,8 @@ def main():
                 if args.fp16:
                     # modify learning rate with special warm up BERT uses
                     # if args.fp16 is False, BertAdam is used that handles this automatically
-                    lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
+                    lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps,
+                                                                      args.warmup_proportion)
                     for param_group in optimizer.param_groups:
                         param_group['lr'] = lr_this_step
                 optimizer.step()
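
The warmup_linear call being split here implements BERT's learning-rate schedule: ramp linearly from 0 to 1 over the first warmup fraction of training, then decay linearly toward 0. A sketch of the helper as it shipped in pytorch_pretrained_bert around this time (reconstructed from how it is called above, so treat the exact signature as an assumption):

    def warmup_linear(x, warmup=0.002):
        # x is the fraction of training completed,
        # i.e. global_step / num_train_optimization_steps.
        if x < warmup:
            return x / warmup  # ramp 0 -> 1 across the warmup fraction
        return 1.0 - x         # then decay linearly toward 0

    # With warmup=0.1 the factor is 0.5 halfway through warmup (x=0.05)
    # and follows 1.0 - x once the warmup fraction has passed.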
@@ -269,6 +269,5 @@ def main():
             metrics_file.write(json.dumps(metrics))
 
 
-
 if __name__ == '__main__':
     main()
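
The last hunk only trims an extra blank line before the entry-point guard: PEP 8 allows at most two consecutive blank lines at module level, and flake8 flags a third as E303. The target layout, in miniature:

    def main():
        print("training finished")


    if __name__ == '__main__':
        main()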