Improve pytorch examples for fp16 (#9796)
* Pad to 8x for fp16 multiple choice example (#9752)
* Pad to 8x for fp16 squad trainer example (#9752)
* Pad to 8x for fp16 ner example (#9752)
* Pad to 8x for fp16 swag example (#9752)
* Pad to 8x for fp16 qa beam search example (#9752)
* Pad to 8x for fp16 qa example (#9752)
* Pad to 8x for fp16 seq2seq example (#9752)
* Pad to 8x for fp16 glue example (#9752)
* Pad to 8x for fp16 new ner example (#9752)
* update script template #9752
* Update examples/multiple-choice/run_swag.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/question-answering/run_qa.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/question-answering/run_qa_beam_search.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve code quality #9752

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
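The change is the same across all of these scripts: when training runs in fp16, dynamically padded batches are rounded up to a sequence length that is a multiple of 8, the shape fp16 Tensor Core kernels handle most efficiently. A minimal sketch of the pattern the commit wires into each example, assuming a BERT checkpoint purely for illustration:

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
fp16 = True  # stand-in for training_args.fp16

# With fp16 off, None is passed so each script keeps its previous behaviour.
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if fp16 else None

batch = data_collator([tokenizer("a short example"), tokenizer("a noticeably longer example sentence")])
assert batch["input_ids"].shape[-1] % 8 == 0  # batch length rounded up to a multiple of 8
```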
parent 781e4b1384
commit 10e5f28212
@@ -28,6 +28,7 @@ from transformers import (
     AutoConfig,
     AutoModelForMultipleChoice,
     AutoTokenizer,
+    DataCollatorWithPadding,
     EvalPrediction,
     HfArgumentParser,
     Trainer,
@@ -188,6 +189,9 @@ def main():
         preds = np.argmax(p.predictions, axis=1)
         return {"acc": simple_accuracy(preds, p.label_ids)}
 
+    # Data collator
+    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
+
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -195,6 +199,7 @@ def main():
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         compute_metrics=compute_metrics,
+        data_collator=data_collator,
     )
 
     # Training
@@ -23,7 +23,14 @@ from dataclasses import dataclass, field
 from typing import Optional
 
 import transformers
-from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset
+from transformers import (
+    AutoConfig,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    HfArgumentParser,
+    SquadDataset,
+)
 from transformers import SquadDataTrainingArguments as DataTrainingArguments
 from transformers import Trainer, TrainingArguments
 from transformers.trainer_utils import is_main_process
@@ -145,12 +152,16 @@ def main():
         else None
     )
 
+    # Data collator
+    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
+
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
+        data_collator=data_collator,
     )
 
     # Training
@@ -30,6 +30,7 @@ from transformers import (
     AutoConfig,
     AutoModelForTokenClassification,
     AutoTokenizer,
+    DataCollatorWithPadding,
     EvalPrediction,
     HfArgumentParser,
     Trainer,
@@ -237,6 +238,9 @@ def main():
             "f1": f1_score(out_label_list, preds_list),
         }
 
+    # Data collator
+    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
+
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -244,6 +248,7 @@ def main():
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         compute_metrics=compute_metrics,
+        data_collator=data_collator,
     )
 
     # Training
@@ -316,7 +316,9 @@ def main():
 
     # Data collator
     data_collator = (
-        default_data_collator if data_args.pad_to_max_length else DataCollatorForMultipleChoice(tokenizer=tokenizer)
+        default_data_collator
+        if data_args.pad_to_max_length
+        else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
     )
 
     # Metric
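The hunk above is the swag example; DataCollatorForMultipleChoice is defined inside run_swag.py itself and forwards its pad_to_multiple_of argument to tokenizer.pad. A small check of that underlying behaviour (the checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
features = [tokenizer("first choice"), tokenizer("a much longer second choice text")]

# padding=True targets the longest sequence in the batch; pad_to_multiple_of
# then rounds that target up to the next multiple of 8.
batch = tokenizer.pad(features, padding=True, pad_to_multiple_of=8, return_tensors="pt")
assert batch["input_ids"].shape[-1] % 8 == 0
```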
@@ -411,7 +411,11 @@ def main():
     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
     # collator.
-    data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
+    data_collator = (
+        default_data_collator
+        if data_args.pad_to_max_length
+        else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
+    )
 
     # Post-processing:
     def post_processing_function(examples, features, predictions):
@@ -448,7 +448,11 @@ def main():
     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
     # collator.
-    data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
+    data_collator = (
+        default_data_collator
+        if data_args.pad_to_max_length
+        else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
+    )
 
     # Post-processing:
     def post_processing_function(examples, features, predictions):
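These two question-answering hunks (run_qa.py and run_qa_beam_search.py) keep default_data_collator when the dataset was already padded to max_length during preprocessing; otherwise the batch is padded to its longest sequence, rounded up under fp16. A worked example of the rounding, with a hypothetical helper name:

```python
import math

# Hypothetical helper mirroring the target the collator pads to: the longest
# sequence in the batch, rounded up to the next multiple of 8.
def padded_length(batch_max: int, multiple: int = 8) -> int:
    return math.ceil(batch_max / multiple) * multiple

# e.g. a batch whose longest sequence has 37 tokens is padded out to 40
assert [padded_length(n) for n in (1, 8, 9, 37)] == [8, 8, 16, 40]
```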
@@ -437,7 +437,11 @@ def main():
     if data_args.pad_to_max_length:
         data_collator = default_data_collator
     else:
-        data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=label_pad_token_id)
+        data_collator = DataCollatorForSeq2Seq(
+            tokenizer,
+            label_pad_token_id=label_pad_token_id,
+            pad_to_multiple_of=8 if training_args.fp16 else None,
+        )
 
     # Metric
     metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
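For the seq2seq example, DataCollatorForSeq2Seq pads the encoder inputs the same way and pads labels with label_pad_token_id. A usage sketch with toy token ids and an illustrative T5 checkpoint; the assert only covers the inputs, since label padding targets the longest label sequence in the batch:

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
# -100 is the label pad id that PyTorch's cross-entropy loss ignores.
collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8)

features = [  # toy token ids
    {"input_ids": [100, 19, 3, 9, 1], "labels": [571, 1]},
    {"input_ids": [100, 19, 3, 9, 388, 13363, 106, 1], "labels": [571, 343, 1]},
]
batch = collator(features)
assert batch["input_ids"].shape[-1] % 8 == 0  # encoder inputs rounded up to 8
```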
@@ -30,6 +30,7 @@ from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    DataCollatorWithPadding,
     EvalPrediction,
     HfArgumentParser,
     PretrainedConfig,
@@ -375,6 +376,14 @@ def main():
         else:
             return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
 
+    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
+    if data_args.pad_to_max_length:
+        data_collator = default_data_collator
+    elif training_args.fp16:
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    else:
+        data_collator = None
+
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -383,8 +392,7 @@ def main():
         eval_dataset=eval_dataset if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=tokenizer,
-        # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
-        data_collator=default_data_collator if data_args.pad_to_max_length else None,
+        data_collator=data_collator,
     )
 
     # Training
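In the GLUE example the collator choice moves out of the Trainer call into a three-way selection. Passing None still works because of the fallback the removed comment alludes to; a sketch of that fallback (an assumption about the Trainer internals of this era, not code from the diff):

```python
from transformers import DataCollatorWithPadding, default_data_collator

def pick_collator(data_collator, tokenizer):
    # Hypothetical helper reproducing the fallback: dynamic padding when a
    # tokenizer was handed to the Trainer, plain tensor collation otherwise.
    if data_collator is not None:
        return data_collator
    return DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator
```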
@@ -327,7 +327,7 @@ def main():
     )
 
     # Data collator
-    data_collator = DataCollatorForTokenClassification(tokenizer)
+    data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
 
     # Metrics
     metric = load_metric("seqeval")
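The new NER example's DataCollatorForTokenClassification pads the labels alongside the inputs, so both come out at the rounded-up length. A usage sketch with toy token ids and labels (checkpoint name illustrative):

```python
from transformers import AutoTokenizer, DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)

features = [  # toy token ids and per-token label ids
    {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
    {"input_ids": [101, 7592, 2088, 999, 102], "labels": [0, 1, 2, 0, 0]},
]
batch = collator(features)
# Labels are padded (with -100 by default) to the same rounded-up length.
assert batch["input_ids"].shape == batch["labels"].shape
assert batch["input_ids"].shape[-1] % 8 == 0
```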
@@ -33,6 +33,7 @@ from transformers import (
     AutoConfig,
     {{cookiecutter.model_class}},
     AutoTokenizer,
+    DataCollatorWithPadding,
     HfArgumentParser,
     Trainer,
     TrainingArguments,
@@ -323,7 +324,7 @@ def main():
     )
 
     # Data collator
-    data_collator=default_data_collator
+    data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
     # Initialize our Trainer
     trainer = Trainer(