Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31)
fix flax examples tests (#14646)
* make tensorboard optional
* update test_fetcher for flax examples
* make the tests slow
This commit is contained in:
parent 03fda7b743
commit 75ae287aec
@@ -40,7 +40,6 @@ import optax
 import transformers
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
-from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -52,6 +51,7 @@ from transformers import (
     HfArgumentParser,
     PreTrainedTokenizerFast,
     TrainingArguments,
+    is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
 from transformers.utils import check_min_version
@@ -716,8 +716,23 @@ def main():
         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

     # Define a summary writer
-    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
-    summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(training_args.output_dir)
+            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     def write_train_metric(summary_writer, train_metrics, train_time, step):
         summary_writer.scalar("train_time", train_time, step)
@@ -833,7 +848,7 @@ def main():
             # Save metrics
             train_metric = unreplicate(train_metric)
             train_time += time.time() - train_start
-            if jax.process_index() == 0:
+            if has_tensorboard and jax.process_index() == 0:
                 write_train_metric(summary_writer, train_metrics, train_time, cur_step)

             epochs.write(
@@ -898,7 +913,7 @@ def main():

             logger.info(f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})")

-            if jax.process_index() == 0:
+            if has_tensorboard and jax.process_index() == 0:
                 write_eval_metric(summary_writer, eval_metrics, cur_step)

         if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
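The hunks above replace an unconditional tensorboard.SummaryWriter with a guarded setup: check is_tensorboard_available(), create the writer only when TensorBoard is importable and this is process 0, and degrade to a warning otherwise. A minimal standalone sketch of the same pattern, using importlib.util.find_spec in place of transformers' is_tensorboard_available(); the helper name make_summary_writer is illustrative, not part of the example scripts:

# Minimal sketch of the optional-TensorBoard pattern, assuming only flax is installed.
import importlib.util
import logging

logger = logging.getLogger(__name__)


def make_summary_writer(output_dir):
    """Return a flax SummaryWriter when tensorboard is importable, else None."""
    if importlib.util.find_spec("tensorboard") is None:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
        return None
    try:
        # Imported lazily, mirroring the diff, so the script works without tensorboard.
        from flax.metrics.tensorboard import SummaryWriter

        return SummaryWriter(output_dir)
    except ImportError as ie:
        logger.warning(f"Unable to display metrics through TensorBoard: {ie}")
        return None

Callers then test the returned writer for None instead of letting an import error abort training.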
@@ -96,6 +96,7 @@ class ExamplesTests(TestCasePlus):
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)

+    @slow
     def test_run_clm(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
@@ -155,6 +156,7 @@ class ExamplesTests(TestCasePlus):
         self.assertGreaterEqual(result["test_rougeL"], 7)
         self.assertGreaterEqual(result["test_rougeLsum"], 7)

+    @slow
     def test_run_mlm(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
@@ -208,6 +210,7 @@ class ExamplesTests(TestCasePlus):
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.42)

+    @slow
     def test_run_ner(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
@@ -240,6 +243,7 @@ class ExamplesTests(TestCasePlus):
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
         self.assertGreaterEqual(result["eval_f1"], 0.3)

+    @slow
     def test_run_qa(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
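These hunks add @slow to each Flax example test in ExamplesTests, so the tests are skipped in the fast CI run and executed only when slow tests are opted into. A sketch of how such a decorator can be built on unittest.skipUnless; transformers ships the real one as slow in transformers.testing_utils, gated by the RUN_SLOW environment variable, and this standalone version only illustrates the mechanism:

import os
import unittest


def slow(test_case):
    """Skip the decorated test unless RUN_SLOW=1 is set in the environment."""
    return unittest.skipUnless(os.getenv("RUN_SLOW", "0") == "1", "test is slow")(test_case)


class ExampleSuite(unittest.TestCase):
    @slow
    def test_expensive_training_run(self):
        # Placeholder body; a real example test would launch a short training run.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()

Run with RUN_SLOW=1 python -m unittest to include the slow test; without the flag it is reported as skipped.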
@@ -33,11 +33,16 @@ import optax
 import transformers
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
-from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
-from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    FlaxAutoModelForSequenceClassification,
+    PretrainedConfig,
+    is_tensorboard_available,
+)
 from transformers.file_utils import get_full_repo_name


@@ -404,8 +409,23 @@ def main():
         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

     # Define a summary writer
-    summary_writer = tensorboard.SummaryWriter(args.output_dir)
-    summary_writer.hparams(vars(args))
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(args.output_dir)
+            summary_writer.hparams(vars(args))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     def write_metric(train_metrics, eval_metrics, train_time, step):
         summary_writer.scalar("train_time", train_time, step)
@@ -513,7 +533,10 @@ def main():
         logger.info(f"    Done! Eval metrics: {eval_metric}")

         cur_step = epoch * (len(train_dataset) // train_batch_size)
-        write_metric(train_metrics, eval_metric, train_time, cur_step)
+
+        # Save metrics
+        if has_tensorboard and jax.process_index() == 0:
+            write_metric(train_metrics, eval_metric, train_time, cur_step)

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
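The hunks above apply the same change to the sequence-classification example (note FlaxAutoModelForSequenceClassification in its imports), and the new guard combines has_tensorboard with jax.process_index() == 0 so that in a multi-host run only the first process writes event files. A hedged sketch of that host-0-only logging idiom; maybe_log and its arguments are illustrative, not taken from the scripts:

import jax


def maybe_log(summary_writer, step, value):
    # summary_writer is None when tensorboard is unavailable (see the sketch
    # earlier); jax.process_index() is nonzero on all but the first host of a
    # multi-host run, so exactly one process writes event files.
    if summary_writer is not None and jax.process_index() == 0:
        summary_writer.scalar("train_loss", value, step)

On a single host jax.process_index() is always 0, so the guard is a no-op there and only matters for pod-scale runs.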
@@ -36,7 +36,6 @@ import optax
 import transformers
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
-from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -46,6 +45,7 @@ from transformers import (
     FlaxAutoModelForTokenClassification,
     HfArgumentParser,
     TrainingArguments,
+    is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
 from transformers.utils import check_min_version
@@ -472,8 +472,23 @@ def main():
         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

     # Define a summary writer
-    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
-    summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(training_args.output_dir)
+            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     def write_train_metric(summary_writer, train_metrics, train_time, step):
         summary_writer.scalar("train_time", train_time, step)
@@ -605,7 +620,7 @@ def main():
             # Save metrics
             train_metric = unreplicate(train_metric)
             train_time += time.time() - train_start
-            if jax.process_index() == 0:
+            if has_tensorboard and jax.process_index() == 0:
                 write_train_metric(summary_writer, train_metrics, train_time, cur_step)

             epochs.write(
@@ -663,7 +678,7 @@ def main():
                     f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
                 )

-                if jax.process_index() == 0:
+                if has_tensorboard and jax.process_index() == 0:
                     write_eval_metric(summary_writer, eval_metrics, cur_step)

             if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
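These hunks repeat the pattern for the token-classification example (FlaxAutoModelForTokenClassification) and guard the calls into write_train_metric / write_eval_metric, which turn accumulated metrics into TensorBoard scalar series. A sketch of what such helpers can look like; everything past the one line visible in the diff (summary_writer.scalar("train_time", train_time, step)) is reconstructed by assumption, not copied from the repository:

# Sketch only: bodies beyond the first scalar call are assumed, not from the diff.
def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)
    # train_metrics is assumed to map metric names to per-step value sequences;
    # write each value back at the step it was produced.
    for name, values in train_metrics.items():
        for i, value in enumerate(values):
            summary_writer.scalar(f"train_{name}", value, step - len(values) + i + 1)


def write_eval_metric(summary_writer, eval_metrics, step):
    # Evaluation produces one value per metric, logged at the current step.
    for name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", value, step)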
@@ -431,6 +431,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
         # Example files are tested separately
         elif f.startswith("examples/pytorch"):
             test_files_to_run.append("examples/pytorch/test_examples.py")
+        elif f.startswith("examples/flax"):
+            test_files_to_run.append("examples/flax/test_examples.py")
         else:
             new_tests = module_to_test_file(f)
             if new_tests is not None:
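The final hunk teaches the test fetcher's infer_tests_to_run to route any modified file under examples/flax to the flax example test file, mirroring the existing examples/pytorch branch. A standalone sketch of that prefix-routing idea; the function name tests_for_changed_file is illustrative, and the real utility also falls back to module_to_test_file for non-example files:

def tests_for_changed_file(f):
    # Example scripts are covered by one suite-level test file per framework
    # rather than by a per-module mapping.
    if f.startswith("examples/pytorch"):
        return ["examples/pytorch/test_examples.py"]
    if f.startswith("examples/flax"):
        return ["examples/flax/test_examples.py"]
    return []  # infer_tests_to_run falls back to module_to_test_file(f) here


# Example: a change to any flax example script (hypothetical path) selects
# the flax example tests.
assert tests_for_changed_file("examples/flax/run_example.py") == ["examples/flax/test_examples.py"]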