diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py index b6232bbed07..0021a719742 100644 --- a/examples/token-classification/run_ner.py +++ b/examples/token-classification/run_ner.py @@ -369,7 +369,7 @@ def main(): ] output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_results_file, "w") as writer: for key, value in metrics.items(): logger.info(f" {key} = {value}") @@ -377,7 +377,7 @@ def main(): # Save predictions output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_predictions_file, "w") as writer: for prediction in true_predictions: writer.write(" ".join(prediction) + "\n") diff --git a/examples/token-classification/run_ner_old.py b/examples/token-classification/run_ner_old.py index 7b1c808062f..035d100945c 100644 --- a/examples/token-classification/run_ner_old.py +++ b/examples/token-classification/run_ner_old.py @@ -291,7 +291,7 @@ def main(): preds_list, _ = align_predictions(predictions, label_ids) output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_results_file, "w") as writer: for key, value in metrics.items(): logger.info(" %s = %s", key, value) @@ -299,7 +299,7 @@ def main(): # Save predictions output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_predictions_file, "w") as writer: with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: token_classification_task.write_predictions_to_file(writer, f, preds_list)