Fix : BitNet tests (#34895)

* fix_tests_bitnet

* fix format
This commit is contained in:
Mohamed Mekkouri 2024-11-25 16:47:14 +01:00 committed by GitHub
parent 9121ab8fe8
commit 4e6b19cd95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -95,16 +95,16 @@ class BitNetTest(unittest.TestCase):
self.assertEqual(nb_linears - 1, nb_bitnet_linear)
def test_quantized_model(self, quantized_model, tokenizer):
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_text = "What are we having for dinner?"
expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
def test_packing_unpacking(self):
"""
@ -113,9 +113,12 @@ class BitNetTest(unittest.TestCase):
from transformers.integrations import pack_weights, unpack_weights
u = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8)
u = torch.randint(0, 255, (256, 256), dtype=torch.uint8)
unpacked_u = unpack_weights(u, dtype=torch.bfloat16)
self.assertEqual(pack_weights(unpacked_u), u)
repacked_u = pack_weights(unpacked_u)
for i in range(u.shape[0]):
for j in range(u.shape[1]):
self.assertEqual(repacked_u[i][j], u[i][j])
def test_activation_quant(self):
"""
@ -127,15 +130,14 @@ class BitNetTest(unittest.TestCase):
layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
layer.to(self.device)
input_tensor = torch.tensor([[1.0, -1.0, -1.0, 1.0], [1.0, -1.0, 1.0, 1.0]], dtype=torch.float32).to(
torch_device
)
input_tensor = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32).to(torch_device)
# Quantize the input tensor
quantized_tensor, scale = layer.activation_quant(input_tensor)
# Verify the output quantized tensor
self.assertEqual(quantized_tensor, input_tensor)
for i in range(input_tensor.shape[0]):
self.assertEqual(quantized_tensor[i] / scale, input_tensor[i])
# Verify the scale tensor
self.assertEqual(scale, 127)