Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 14:58:56 +06:00)
clean up logger in examples for distributed case
This commit is contained in:
parent cc43307023
commit 1135f2384a
README.md (16 changed lines)
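In short: the example scripts previously called `logging.basicConfig(...)` at import time with a fixed `INFO` level, so in multi-process distributed runs every worker printed the same log lines. This commit moves the call into `main()`, after `torch.distributed.init_process_group`, and keys the level on `args.local_rank`: rank 0 (and the non-distributed case, `local_rank == -1`) logs at `INFO`, all other ranks at `WARN`. A minimal sketch of the resulting pattern (the argument name and logging format are taken from the diff; the surrounding scaffolding is simplified):

```python
import argparse
import logging

logger = logging.getLogger(__name__)

def main():
    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank to each worker;
    # the default -1 means "not distributed".
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()

    # Configure logging only once the rank is known: INFO on rank 0
    # (or in single-process runs), WARN on every other rank.
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("visible on rank 0 and in single-process runs only")

if __name__ == "__main__":
    main()
```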
````diff
@@ -1274,18 +1274,20 @@ To get these results we used a combination of:
 
 Here is the full list of hyper-parameters for this run:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
   --bert_model bert-large-uncased \
   --do_train \
   --do_predict \
   --do_lower_case \
-  --train_file $SQUAD_TRAIN \
-  --predict_file $SQUAD_EVAL \
+  --train_file $SQUAD_DIR/train-v1.1.json \
+  --predict_file $SQUAD_DIR/dev-v1.1.json \
   --learning_rate 3e-5 \
   --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir $OUTPUT_DIR \
+  --output_dir /tmp/debug_squad/ \
   --train_batch_size 24 \
   --gradient_accumulation_steps 2
 ```
````
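In this first README hunk, the undefined `$SQUAD_TRAIN`, `$SQUAD_EVAL`, and `$OUTPUT_DIR` placeholders are replaced with paths built from a single exported `SQUAD_DIR` variable and a concrete `/tmp/debug_squad/` output directory, so the snippet is runnable as written.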
````diff
@@ -1294,18 +1296,20 @@ If you have a recent GPU (starting from NVIDIA Volta series), you should try **1
 
 Here is an example of hyper-parameters for a FP16 run we tried:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
   --bert_model bert-large-uncased \
   --do_train \
   --do_predict \
   --do_lower_case \
-  --train_file $SQUAD_TRAIN \
-  --predict_file $SQUAD_EVAL \
+  --train_file $SQUAD_DIR/train-v1.1.json \
+  --predict_file $SQUAD_DIR/dev-v1.1.json \
   --learning_rate 3e-5 \
   --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir $OUTPUT_DIR \
+  --output_dir /tmp/debug_squad/ \
   --train_batch_size 24 \
   --fp16 \
   --loss_scale 128
````
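The FP16 example receives the same path cleanup; it differs from the run above only in using `--fp16` with `--loss_scale 128` in place of `--gradient_accumulation_steps 2`.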
```diff
@@ -40,9 +40,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification, Bert
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
 logger = logging.getLogger(__name__)
 
 
```
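Dropping the module-level call is the crux of the cleanup: it ran at import time, before `args.local_rank` could be known, and because `logging.basicConfig` is a no-op once the root logger already has handlers, leaving it in place would have prevented the rank-aware configuration added below from taking effect.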
```diff
@@ -697,6 +694,11 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
+
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
 
```
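The same pair of edits, removing the import-time configuration and re-adding it inside `main()` after the distributed backend is initialized, is applied to a second example script: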
```diff
@@ -46,9 +46,6 @@ if sys.version_info[0] == 2:
 else:
     import pickle
 
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
 logger = logging.getLogger(__name__)
 
 
```
```diff
@@ -848,6 +845,11 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
+
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
 
```
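As a sanity check, here is a hypothetical, self-contained demo (not part of the commit) that simulates four workers with `multiprocessing` and applies the same level rule; only the rank 0 process emits the INFO line, while warnings still surface on every rank:

```python
import logging
import multiprocessing as mp

def worker(local_rank: int) -> None:
    # Same rule as the commit: INFO for rank 0 / non-distributed (-1), WARN otherwise.
    logging.basicConfig(
        format='rank {} - %(levelname)s - %(message)s'.format(local_rank),
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)
    log = logging.getLogger(__name__)
    log.info("device setup done")                    # printed once, by rank 0
    log.warning("warnings still show on every rank") # printed by all four ranks

if __name__ == "__main__":
    workers = [mp.Process(target=worker, args=(rank,)) for rank in range(4)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
```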