mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
OLMo 7B Twin 2T integration test fix
This commit is contained in:
parent
0b227bb2af
commit
9ff65a4a29
@ -384,10 +384,10 @@ class OLMoIntegrationTest(unittest.TestCase):
|
||||
model = OLMoForCausalLM.from_pretrained("allenai/OLMo-7B-Twin-2T-hf")
|
||||
out = model(torch.tensor(input_ids)).logits
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-0.3636, -0.3825, -0.4800, -0.3696, -0.8388, -0.9737, -0.9849, -0.8356]])
|
||||
EXPECTED_MEAN = torch.tensor([[-0.4083, -0.5103, -0.2002, -0.3794, -0.6669, -0.8921, -0.9318, -0.7040]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([-2.0833, -1.9234, 8.7312, 7.8049, 1.0372, 0.8941, 3.1548, 1.8502, 5.5511, 5.5793, 8.1166, 4.5906, 1.8691, 11.6377, 8.9858, 11.6447, 7.4549, 1.4725, 2.8399, 2.7568, 1.4011, 1.6958, 0.5572, 0.5231, 0.3068, 0.5364, 0.6769, 7.9636, 8.2379, 1.7950]) # fmt: skip
|
||||
EXPECTED_SLICE = torch.tensor([-1.9491, -1.9579, 9.4918, 7.8442, 2.4652, 2.4859, 1.8890, 3.2533, 6.4323, 4.8175, 7.3225, 3.4401, 2.3375, 11.8406, 8.2725, 12.4737, 6.5298, 1.6334, 2.7099, 2.7729, 1.8560, 1.6340, 0.6442, 0.5034, 0.0823, 0.5256, 0.6036, 6.3053, 7.9086, 3.2711]) # fmt: skip
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@unittest.skip("Model is curently gated")
|
||||
|
Loading…
Reference in New Issue
Block a user