mirror of https://github.com/huggingface/transformers.git
Fix args.gradient_accumulation_steps used before assignment.
parent 649e9774cd
commit 290633b882
@@ -404,6 +404,10 @@ def main():
                         type=int,
                         default=42,
                         help="random seed for initialization")
+    parser.add_argument('--gradient_accumulation_steps',
+                        type=int,
+                        default=1,
+                        help="Number of update steps to accumulate before performing a backward/update pass.")
     args = parser.parse_args()
 
     processors = {
@@ -469,7 +473,7 @@ def main():
 
     model = BertForSequenceClassification(bert_config, len(label_list))
     if args.init_checkpoint is not None:
-        model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
 
     if args.local_rank != -1:
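This hunk also changes where the initial checkpoint is loaded: into the model.bert submodule instead of the full classification model. Presumably the checkpoint holds only the pretrained BERT encoder weights, so its keys match the encoder submodule rather than the full model with its freshly initialized classifier head. A minimal, self-contained sketch of that pattern, with TinyEncoder and TinyClassifier as hypothetical stand-ins for the repository's classes:

import torch
from torch import nn

# Hypothetical stand-ins for BertModel / BertForSequenceClassification,
# used only to illustrate the submodule-level load.
class TinyEncoder(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)

class TinyClassifier(nn.Module):
    def __init__(self, hidden=8, num_labels=2):
        super().__init__()
        self.bert = TinyEncoder(hidden)                   # pretrained body
        self.classifier = nn.Linear(hidden, num_labels)   # fresh task head

torch.save(TinyEncoder().state_dict(), "encoder.pt")      # encoder-only checkpoint

model = TinyClassifier()
# The checkpoint's keys ("layer.weight", ...) match the encoder submodule,
# not the full model ("bert.layer.weight", ...), so load at model.bert:
model.bert.load_state_dict(torch.load("encoder.pt", map_location="cpu"))

Loading at the submodule level keeps the task head's random initialization intact and avoids missing-key errors for the classifier parameters.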
@@ -739,7 +739,11 @@ def main():
                         type=int,
                         default=42,
                         help="random seed for initialization")
+    parser.add_argument('--gradient_accumulation_steps',
+                        type=int,
+                        default=1,
+                        help="Number of update steps to accumulate before performing a backward/update pass.")
 
     args = parser.parse_args()
 
     if args.local_rank == -1 or args.no_cuda:
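Both argparse hunks register the --gradient_accumulation_steps flag before args = parser.parse_args(), so later reads of args.gradient_accumulation_steps no longer fail; without the definition, the first such read raises an AttributeError, which is the "used before assignment" the commit message refers to. For context, here is a minimal, self-contained sketch of how a flag like this is typically consumed in a PyTorch training loop (an assumed illustration, not the repository's exact code; the model, optimizer, and batches are placeholders):

import torch
from torch import nn

gradient_accumulation_steps = 4   # stands in for args.gradient_accumulation_steps
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale the loss so the accumulated gradient matches one large batch.
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()          # one parameter update per N micro-batches
        optimizer.zero_grad()

Scaling the loss by the accumulation count keeps gradient magnitudes comparable to a single large-batch update, which is also why the flag's default of 1 leaves existing behavior unchanged.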