mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
wrap forward passes with torch.no_grad() (#19438)
This commit is contained in:
parent
692c5be74e
commit
870a9542be
@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_masked_lm(self):
|
||||
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
# TODO Replace vocab size
|
||||
vocab_size = 50000
|
||||
|
Loading…
Reference in New Issue
Block a user