mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix args.gradient_accumulation_steps
used before assigment.
This commit is contained in:
parent
649e9774cd
commit
290633b882
@ -404,6 +404,10 @@ def main():
|
||||
type=int,
|
||||
default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||
args = parser.parse_args()
|
||||
|
||||
processors = {
|
||||
@ -469,7 +473,7 @@ def main():
|
||||
|
||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||
if args.init_checkpoint is not None:
|
||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
if args.local_rank != -1:
|
||||
|
@ -739,6 +739,10 @@ def main():
|
||||
type=int,
|
||||
default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user