mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
wrap forward passes with torch.no_grad() (#19416)
This commit is contained in:
parent
870a9542be
commit
d739a707d9
@ -570,6 +570,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
table, queries = prepare_tapas_single_inputs_for_inference()
|
||||
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
|
||||
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
# test the sequence output
|
||||
expected_slice = torch.tensor(
|
||||
@ -608,6 +609,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
table, queries = prepare_tapas_single_inputs_for_inference()
|
||||
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
|
||||
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
# test the logits
|
||||
logits = outputs.logits
|
||||
@ -657,6 +659,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
table, queries = prepare_tapas_single_inputs_for_inference()
|
||||
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
|
||||
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
# test the logits
|
||||
logits = outputs.logits
|
||||
@ -705,6 +708,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
|
||||
inputs_on_device = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_on_device)
|
||||
# test the logits
|
||||
logits = outputs.logits
|
||||
@ -774,6 +778,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
float_answer = torch.FloatTensor(float_answer).to(torch_device)
|
||||
|
||||
# forward pass to get loss + logits:
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@ -829,6 +834,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
table, queries = prepare_tapas_single_inputs_for_inference()
|
||||
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
|
||||
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
# test the logits
|
||||
logits = outputs.logits
|
||||
@ -884,6 +890,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
table, queries = prepare_tapas_single_inputs_for_inference()
|
||||
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
|
||||
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# test the classification logits
|
||||
|
Loading…
Reference in New Issue
Block a user