mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Cleanup TPU bits from run_glue.py
TPU runner is currently implemented in: https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py. We plan to upstream this directly into `huggingface/transformers` (either `master` or `tpu`) branch once it's been more thoroughly tested.
This commit is contained in:
parent
454455c695
commit
e70cdf083d
@ -158,7 +158,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
@ -189,11 +189,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.tpu:
|
||||
args.xla_model.optimizer_step(optimizer, barrier=True)
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
@ -397,15 +392,6 @@ def main():
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--tpu', action='store_true',
|
||||
help="Whether to run on the TPU defined in the environment variables")
|
||||
parser.add_argument('--tpu_ip_address', type=str, default='',
|
||||
help="TPU IP address if none are set in the environment variables")
|
||||
parser.add_argument('--tpu_name', type=str, default='',
|
||||
help="TPU name if none are set in the environment variables")
|
||||
parser.add_argument('--xrt_tpu_config', type=str, default='',
|
||||
help="XRT TPU config if none are set in the environment variables")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
@ -439,23 +425,6 @@ def main():
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
if args.tpu:
|
||||
if args.tpu_ip_address:
|
||||
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
|
||||
if args.tpu_name:
|
||||
os.environ["TPU_NAME"] = args.tpu_name
|
||||
if args.xrt_tpu_config:
|
||||
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
|
||||
|
||||
assert "TPU_IP_ADDRESS" in os.environ
|
||||
assert "TPU_NAME" in os.environ
|
||||
assert "XRT_TPU_CONFIG" in os.environ
|
||||
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
args.device = xm.xla_device()
|
||||
args.xla_model = xm
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@ -509,7 +478,7 @@ def main():
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
Loading…
Reference in New Issue
Block a user