wrap forward passes with torch.no_grad() (#19438)

This commit is contained in:
Partho 2022-10-11 00:24:54 +05:30 committed by GitHub
parent 692c5be74e
commit 870a9542be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase):
def test_inference_masked_lm(self):
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
with torch.no_grad():
output = model(input_ids)[0]
# TODO Replace vocab size
vocab_size = 50000