mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Update no trainer scripts for language modeling and image classification examples (#18443)
* Update no_trainer script for image-classification * Update no_trainer scripts for language-modeling examples * Remove unused variable * Removing truncation from losses array for language modeling examples
This commit is contained in:
parent
10e1ec9a8c
commit
3db4378bd7
@ -508,19 +508,11 @@ def main():
|
||||
break
|
||||
|
||||
model.eval()
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
||||
# If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader) - 1:
|
||||
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
|
||||
references = references[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += references.shape[0]
|
||||
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
|
@ -597,10 +597,9 @@ def main():
|
||||
outputs = model(**batch)
|
||||
|
||||
loss = outputs.loss
|
||||
losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
|
||||
losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[: len(eval_dataset)]
|
||||
try:
|
||||
eval_loss = torch.mean(losses)
|
||||
perplexity = math.exp(eval_loss)
|
||||
|
@ -642,10 +642,9 @@ def main():
|
||||
outputs = model(**batch)
|
||||
|
||||
loss = outputs.loss
|
||||
losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
|
||||
losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[: len(eval_dataset)]
|
||||
try:
|
||||
eval_loss = torch.mean(losses)
|
||||
perplexity = math.exp(eval_loss)
|
||||
|
Loading…
Reference in New Issue
Block a user