Mirror of https://github.com/huggingface/transformers.git
add distilbert + update run_xnli wrt run_glue
parent 07ab8d7af6
commit d5478b939d
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning multi-lingual models on XNLI (Bert, XLM).
+""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
 Adapted from `examples/run_glue.py`"""
 
 from __future__ import absolute_import, division, print_function
@@ -42,7 +42,7 @@ from transformers import (WEIGHTS_NAME,
                           XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                           DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
 
-from transformers import AdamW, WarmupLinearSchedule
+from transformers import AdamW, get_linear_schedule_with_warmup
 
 from transformers import xnli_compute_metrics as compute_metrics
 from transformers import xnli_output_modes as output_modes
@@ -52,12 +52,12 @@ from transformers import glue_convert_examples_to_features as convert_examples_to_features
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLMConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ())
 
 MODEL_CLASSES = {
     'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
     'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
-    # 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
+    'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
 }
 
 
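Note: a minimal sketch (not part of the diff) of how such a (config, model, tokenizer) registry is consumed in these example scripts. The helper name load_model and the checkpoint name are illustrative assumptions; the from_pretrained calls follow the standard transformers API.

from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,
                          XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                          DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)

MODEL_CLASSES = {
    'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
    'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
}

def load_model(model_type, model_name_or_path, num_labels=3):
    # load_model is a hypothetical helper; the lookup pattern mirrors run_glue/run_xnli.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
    config = config_class.from_pretrained(model_name_or_path, num_labels=num_labels, finetuning_task='xnli')
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path, config=config)
    return config, model, tokenizer

# e.g. load_model('distilbert', 'distilbert-base-multilingual-cased')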
@@ -91,7 +91,7 @@ def train(args, train_dataset, model, tokenizer):
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
     if args.fp16:
         try:
             from apex import amp
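Note: the WarmupLinearSchedule class is replaced here by the get_linear_schedule_with_warmup factory. A minimal standalone sketch of the AdamW/scheduler pairing, with a placeholder model and made-up step counts standing in for args.warmup_steps and t_total:

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 3)                       # placeholder for the real model
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=100,     # plays the role of args.warmup_steps
                                            num_training_steps=1000)  # plays the role of t_total

for step in range(1000):
    # ... forward pass, loss.backward() ...
    optimizer.step()
    scheduler.step()        # advance the linear warmup/decay after each optimizer step
    optimizer.zero_grad()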
@@ -149,7 +149,7 @@ def train(args, train_dataset, model, tokenizer):
                 loss.backward()
 
             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
+            if (step + 1) % args.gradient_accumulation_steps == 0:
                 if args.fp16:
                     torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                 else:
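Note: a generic sketch of the gradient-accumulation and clipping pattern that the guard above controls; the model and data are placeholders, and gradient_accumulation_steps / max_grad_norm stand in for the script's arguments.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4
max_grad_norm = 1.0

for step in range(16):
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = F.cross_entropy(model(x), y) / gradient_accumulation_steps  # scale so gradients sum to one full batch
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()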
@@ -180,11 +180,6 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
-            if args.tpu:
-                args.xla_model.optimizer_step(optimizer, barrier=True)
-                model.zero_grad()
-                global_step += 1
-
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
@@ -214,6 +209,10 @@ def evaluate(args, model, tokenizer, prefix=""):
         eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
         eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 
+        # multi-gpu eval
+        if args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
+
         # Eval!
         logger.info("***** Running evaluation {} *****".format(prefix))
         logger.info("  Num examples = %d", len(eval_dataset))
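Note: the added multi-gpu branch simply wraps the already-loaded model in DataParallel for evaluation. A minimal sketch with a placeholder model:

import torch

model = torch.nn.Linear(8, 3)                                    # placeholder for the fine-tuned model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)                         # splits each eval batch across visible GPUs

model.eval()
with torch.no_grad():
    logits = model(torch.randn(16, 8, device=device))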
@@ -383,15 +382,6 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
 
-    parser.add_argument('--tpu', action='store_true',
-                        help="Whether to run on the TPU defined in the environment variables")
-    parser.add_argument('--tpu_ip_address', type=str, default='',
-                        help="TPU IP address if none are set in the environment variables")
-    parser.add_argument('--tpu_name', type=str, default='',
-                        help="TPU name if none are set in the environment variables")
-    parser.add_argument('--xrt_tpu_config', type=str, default='',
-                        help="XRT TPU config if none are set in the environment variables")
-
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -425,23 +415,6 @@ def main():
         args.n_gpu = 1
     args.device = device
 
-    if args.tpu:
-        if args.tpu_ip_address:
-            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
-        if args.tpu_name:
-            os.environ["TPU_NAME"] = args.tpu_name
-        if args.xrt_tpu_config:
-            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
-
-        assert "TPU_IP_ADDRESS" in os.environ
-        assert "TPU_NAME" in os.environ
-        assert "XRT_TPU_CONFIG" in os.environ
-
-        import torch_xla
-        import torch_xla.core.xla_model as xm
-        args.device = xm.xla_device()
-        args.xla_model = xm
-
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
@@ -495,7 +468,7 @@ def main():
 
 
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
-    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
             os.makedirs(args.output_dir)
@@ -512,7 +485,7 @@ def main():
 
         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)
 
 
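Note: do_lower_case no longer needs to be repeated when reloading, presumably because save_pretrained stores the tokenizer's settings alongside the vocabulary. A minimal sketch of the save/reload round trip; the output directory and checkpoint names are placeholders:

from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-multilingual-cased', num_labels=3)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')

model.save_pretrained('./xnli_output')        # writes pytorch_model.bin + config.json
tokenizer.save_pretrained('./xnli_output')    # writes the vocab and tokenizer settings

# Reload later without repeating tokenizer options such as do_lower_case.
model = DistilBertForSequenceClassification.from_pretrained('./xnli_output')
tokenizer = DistilBertTokenizer.from_pretrained('./xnli_output')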