mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[lm examples] fix overflow in perplexity calc (#11855)
* fix overflow in perplexity calc * use inf * fix
This commit is contained in:
parent
7630c11f32
commit
6287c929c1
@ -440,7 +440,10 @@ def main():
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
|
@ -442,7 +442,10 @@ def main():
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[: len(eval_dataset)]
|
||||
perplexity = math.exp(torch.mean(losses))
|
||||
try:
|
||||
perplexity = math.exp(torch.mean(losses))
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
|
||||
logger.info(f"epoch {epoch}: perplexity: {perplexity}")
|
||||
|
||||
|
@ -469,7 +469,10 @@ def main():
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
|
@ -486,7 +486,10 @@ def main():
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[: len(eval_dataset)]
|
||||
perplexity = math.exp(torch.mean(losses))
|
||||
try:
|
||||
perplexity = math.exp(torch.mean(losses))
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
|
||||
logger.info(f"epoch {epoch}: perplexity: {perplexity}")
|
||||
|
||||
|
@ -445,7 +445,10 @@ def main():
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
|
Loading…
Reference in New Issue
Block a user