mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[tests] fix mamba integration simple inference precision issue (#37193)
* fix precision issue * use float32
This commit is contained in:
parent
6ce238fe7a
commit
a0803a9555
@ -451,7 +451,7 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float32)
|
||||
model.to(device)
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
|
||||
|
||||
@ -464,14 +464,13 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
-55.6875, -69.8750, -49.9062, -51.7500, -57.6875, -57.9375, -56.9688,
|
||||
-57.9375, -54.6875, -55.9375, -55.3125, -58.0938, -60.5625, -47.0000,
|
||||
-52.0312, -49.7812, -55.9375, -57.9062, -56.7812, -57.1250, -57.3438,
|
||||
-58.3125, -57.8125, -58.7812, -59.6250, -59.0938, -58.7188, -52.9375,
|
||||
-53.4688, -57.3750, -56.9375, -55.7500, -53.3125, -55.8438, -57.0000,
|
||||
-56.9062, -56.2188, -54.7188, -56.4375, -57.5000
|
||||
]
|
||||
,dtype=torch.float32) # fmt: skip
|
||||
-55.6909, -69.7903, -49.8981, -51.7581, -57.6544, -57.9368, -56.9591,
|
||||
-57.9033, -54.6787, -55.9261, -55.3011, -58.0765, -60.5642, -47.0176,
|
||||
-52.0344, -49.7836, -55.9463, -57.8957, -56.7627, -57.1080, -57.3434,
|
||||
-58.3015, -57.7875, -58.7760, -59.6037, -59.0665, -58.7087, -52.9293,
|
||||
-53.4654, -57.3466, -56.9294, -55.7314, -53.3141, -55.8171, -56.9879,
|
||||
-56.9121, -56.2139, -54.7198, -56.4134, -57.4825
|
||||
]) # fmt: skip
|
||||
|
||||
torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user