mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
wrap forward passes with torch.no_grad() (#19438)
This commit is contained in:
parent
692c5be74e
commit
870a9542be
@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm(self):
|
||||||
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
|
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
|
||||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||||
output = model(input_ids)[0]
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
# TODO Replace vocab size
|
# TODO Replace vocab size
|
||||||
vocab_size = 50000
|
vocab_size = 50000
|
||||||
|
Loading…
Reference in New Issue
Block a user