[lm examples] fix overflow in perplexity calc (#11855)

* fix overflow in perplexity calc

* use inf

* fix
This commit is contained in:
Stas Bekman 2021-05-25 08:11:26 -07:00 committed by GitHub
parent 7630c11f32
commit 6287c929c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 5 deletions

View File

@ -440,7 +440,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)

View File

@ -442,7 +442,10 @@ def main():
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
perplexity = math.exp(torch.mean(losses))
try:
perplexity = math.exp(torch.mean(losses))
except OverflowError:
perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}")

View File

@ -469,7 +469,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)

View File

@ -486,7 +486,10 @@ def main():
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
perplexity = math.exp(torch.mean(losses))
try:
perplexity = math.exp(torch.mean(losses))
except OverflowError:
perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}")

View File

@ -445,7 +445,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)