Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 18:51:14 +06:00
[lm examples] fix overflow in perplexity calc (#11855)
* fix overflow in perplexity calc
* use inf
* fix
This commit is contained in:
parent 7630c11f32
commit 6287c929c1
@@ -440,7 +440,10 @@ def main():
 
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-        perplexity = math.exp(metrics["eval_loss"])
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
         metrics["perplexity"] = perplexity
 
         trainer.log_metrics("eval", metrics)
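For context on the hunks in this commit: math.exp raises OverflowError once its argument exceeds roughly 709.78 (math.log(sys.float_info.max)), which a sufficiently large eval loss can reach early in training. Below is a minimal, standalone sketch of the guard these hunks apply; the eval_loss value is made up purely to trigger the overflow.

import math
import sys

# math.exp overflows once its argument exceeds log(float max) ~= 709.78.
print(math.log(sys.float_info.max))  # ~709.78

eval_loss = 750.0  # hypothetical, unusually large eval loss
try:
    perplexity = math.exp(eval_loss)
except OverflowError:
    # Fall back to infinity instead of crashing the eval step.
    perplexity = float("inf")

print(perplexity)  # inf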
@@ -442,7 +442,10 @@ def main():
 
         losses = torch.cat(losses)
         losses = losses[: len(eval_dataset)]
-        perplexity = math.exp(torch.mean(losses))
+        try:
+            perplexity = math.exp(torch.mean(losses))
+        except OverflowError:
+            perplexity = float("inf")
 
         logger.info(f"epoch {epoch}: perplexity: {perplexity}")
 
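The second pattern in this commit (presumably the `*_no_trainer` example scripts) passes the mean of a loss tensor to math.exp; torch.mean returns a 0-dim tensor, which math.exp coerces to a Python float before the same overflow behavior applies. A hedged sketch with a made-up loss tensor:

import math
import torch

# Hypothetical per-sample eval losses; the value is chosen only to force an overflow.
losses = torch.full((4,), 800.0)

try:
    # torch.mean returns a 0-dim tensor; math.exp converts it to a Python float.
    perplexity = math.exp(torch.mean(losses))
except OverflowError:
    perplexity = float("inf")

print(perplexity)  # inf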
@@ -469,7 +469,10 @@ def main():
 
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-        perplexity = math.exp(metrics["eval_loss"])
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
         metrics["perplexity"] = perplexity
 
         trainer.log_metrics("eval", metrics)
@@ -486,7 +486,10 @@ def main():
 
         losses = torch.cat(losses)
         losses = losses[: len(eval_dataset)]
-        perplexity = math.exp(torch.mean(losses))
+        try:
+            perplexity = math.exp(torch.mean(losses))
+        except OverflowError:
+            perplexity = float("inf")
 
         logger.info(f"epoch {epoch}: perplexity: {perplexity}")
 
@@ -445,7 +445,10 @@ def main():
 
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-        perplexity = math.exp(metrics["eval_loss"])
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
         metrics["perplexity"] = perplexity
 
         trainer.log_metrics("eval", metrics)