mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
token-classification: use is_world_process_zero instead of deprecated is_world_master() (#8828)
This commit is contained in:
parent
40ecaf0c2b
commit
19fa01ce2a
@ -369,7 +369,7 @@ def main():
|
||||
]
|
||||
|
||||
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
if trainer.is_world_master():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key, value in metrics.items():
|
||||
logger.info(f" {key} = {value}")
|
||||
@ -377,7 +377,7 @@ def main():
|
||||
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
if trainer.is_world_master():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
for prediction in true_predictions:
|
||||
writer.write(" ".join(prediction) + "\n")
|
||||
|
@ -291,7 +291,7 @@ def main():
|
||||
preds_list, _ = align_predictions(predictions, label_ids)
|
||||
|
||||
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
if trainer.is_world_master():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key, value in metrics.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
@ -299,7 +299,7 @@ def main():
|
||||
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
if trainer.is_world_master():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
|
||||
token_classification_task.write_predictions_to_file(writer, f, preds_list)
|
||||
|
Loading…
Reference in New Issue
Block a user