Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
add torch.no_grad when in eval mode (#17020)
* add torch.no_grad when in eval mode
* make style quality
This commit is contained in:
parent 9586e222af
commit bdd690a74d
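Every hunk below makes the same change: the evaluation forward pass is wrapped in torch.no_grad(), so autograd stops recording operations and storing activations during eval, reducing memory use and speeding up inference. A minimal, self-contained sketch of the pattern (the Linear model and random data are illustrative stand-ins for the real transformers model and eval_dataloader):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins; the real scripts use a transformers model and a
# tokenized dataset prepared by accelerate.
model = nn.Linear(4, 2)
eval_dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

model.eval()  # switches off dropout / batch-norm updates, but NOT autograd
for step, (inputs,) in enumerate(eval_dataloader):
    # torch.no_grad() is what actually disables gradient tracking, so no
    # graph is built for the eval forward pass.
    with torch.no_grad():
        logits = model(inputs)
    predictions = logits.argmax(dim=-1)

Note that model.eval() alone does not disable gradient tracking; it only changes layer behavior (dropout, batch norm), which is why the explicit no_grad context is needed on top of it.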
@@ -469,7 +469,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
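For context, the samples_seen counter and the duplicates comment in this hunk belong to the gather logic that follows it: under accelerate, the last eval batch may be padded with duplicate samples so every process receives a full batch, and those extras must be truncated before scoring. A sketch of that pattern, paraphrased from these example scripts rather than taken from this diff (accelerator, model, eval_dataloader, and metric are assumed to have been set up via accelerator.prepare):

# assumes: accelerator = Accelerator(); model, eval_dataloader = accelerator.prepare(...)
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
    predictions = outputs.logits.argmax(dim=-1)
    predictions, references = accelerator.gather((predictions, batch["labels"]))
    if accelerator.num_processes > 1:
        if step == len(eval_dataloader) - 1:
            # Last batch: keep only as many samples as the dataset actually has.
            predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
            references = references[: len(eval_dataloader.dataset) - samples_seen]
        else:
            samples_seen += references.shape[0]
    metric.add_batch(predictions=predictions, references=references)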
@@ -579,7 +579,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
 
             upsampled_logits = torch.nn.functional.interpolate(
                 outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False
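This hunk is from what appears to be a semantic-segmentation example: the model emits logits at a reduced spatial resolution, so they are bilinearly upsampled to the label map's size before taking the per-pixel argmax. A self-contained sketch of that step, with illustrative shapes:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2, 19 classes, logits at 1/4 the
# resolution of the 128x128 label maps.
logits = torch.randn(2, 19, 32, 32)
labels = torch.randint(0, 19, (2, 128, 128))

with torch.no_grad():
    upsampled_logits = F.interpolate(
        logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
    )
    predictions = upsampled_logits.argmax(dim=1)  # per-pixel class ids

print(predictions.shape)  # torch.Size([2, 128, 128])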
@@ -22,6 +22,7 @@ import random
 from pathlib import Path
 
 import datasets
+import torch
 from datasets import load_dataset, load_metric
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
@@ -514,7 +515,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
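The conditional in this hunk covers two task types in one line: classification takes the argmax over the label dimension, while a regression task (a single output unit, as in GLUE's STS-B) just squeezes away the trailing dimension so predictions align with the (batch,) labels. A tiny illustration:

import torch

# Classification head: one logit per label; the prediction is the argmax.
logits = torch.tensor([[0.2, 1.3], [2.0, -1.0]])
print(logits.argmax(dim=-1))  # tensor([1, 0])

# Regression head: a single output unit; squeeze drops the trailing dim.
logits = torch.tensor([[0.73], [4.10]])
print(logits.squeeze())  # tensor([0.7300, 4.1000])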
@@ -28,6 +28,7 @@ from dataclasses import dataclass, field
 from typing import Optional, List
 
 import datasets
+import torch
 from datasets import load_dataset
 
 import transformers
@@ -871,7 +872,8 @@ def main():
 
     model.eval()
     for step, batch in enumerate(eval_dataloader):
-        outputs = model(**batch)
+        with torch.no_grad():
+            outputs = model(**batch)
         predictions = outputs.logits.argmax(dim=-1)
         metric.add_batch(
             predictions=accelerator.gather(predictions),
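The final hunk feeds gathered predictions into a datasets metric, which accumulates batches and is scored once with compute() after the loop. A minimal single-process sketch (load_metric matches the import added above; accuracy is an illustrative choice of metric):

import torch
from datasets import load_metric

metric = load_metric("accuracy")

# Stand-in eval loop: two batches of predictions and references.
for predictions, references in [
    (torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1])),
    (torch.tensor([0, 0]), torch.tensor([0, 1])),
]:
    metric.add_batch(predictions=predictions, references=references)

print(metric.compute())  # {'accuracy': 0.6}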