Fix PersimmonIntegrationTest OOM (#26750)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Author: Yih-Dar, 2023-10-12 11:24:18 +02:00, committed by GitHub
parent ab0ddc99e8
commit 72256bc72a
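The core of the fix: instead of loading the full fp16 checkpoint onto the GPU, the tests now load the 8B chat model with int8 weight quantization, which roughly halves the weight memory. Below is a minimal sketch of that loading call, assuming a CUDA GPU and the `bitsandbytes` package are installed; the kwargs mirror the test, and `load_in_8bit` was the accepted `from_pretrained` argument at the time (later superseded by `BitsAndBytesConfig`).

    import torch
    from transformers import PersimmonForCausalLM

    # Sketch only: int8-quantized load of the checkpoint used by the tests.
    model = PersimmonForCausalLM.from_pretrained(
        "adept/persimmon-8b-chat",
        load_in_8bit=True,          # quantize linear-layer weights to int8 via bitsandbytes
        device_map={"": 0},         # keep the whole model on GPU 0, no CPU offload
        torch_dtype=torch.float16,  # non-quantized modules stay in fp16
    )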


@@ -15,6 +15,7 @@
 """ Testing suite for the PyTorch Persimmon model. """
 
 
+import gc
 import unittest
 
 from parameterized import parameterized
@@ -395,19 +396,27 @@ class PersimmonIntegrationTest(unittest.TestCase):
     @require_torch_gpu
     def test_model_8b_chat_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
         model = PersimmonForCausalLM.from_pretrained(
-            "adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16
+            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
         )
-        out = model(torch.tensor([input_ids])).logits
+        out = model(torch.tensor([input_ids], device=torch_device)).logits
         EXPECTED_MEAN = torch.tensor(
-            [[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16
+            [[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
         )
-        torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
+        # change dtype to `torch.float32` before calling `mean` to avoid `nan` values
+        torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
         # fmt: off
-        EXPECTED_SLICE = torch.tensor([-16.9670, -16.9647, -16.9649, -16.9630, -16.9577, -16.9623, -17.0164, -16.9673, -16.9648, -16.9668, -17.0160, -16.9651, -17.0156, -16.9668, -16.9655, -16.9653, -16.9665, -16.9682, -17.0112, -16.9667, -16.9717, -16.9654, -16.9650, -16.9701, -16.9657, -17.0160, -16.9676, -17.0138, -16.9610, -16.9695])
+        EXPECTED_SLICE = torch.tensor(
+            [-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062],
+            dtype=torch.float16
+        )
         # fmt: on
         torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
+
+        torch.cuda.empty_cache()
+        del model
+        gc.collect()
 
     @slow
     @require_torch_gpu
     def test_model_8b_chat_greedy_generation(self):
@@ -415,11 +424,15 @@
         prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
         tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
-        model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to(
-            torch_device
-        )
+        model = PersimmonForCausalLM.from_pretrained(
+            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
+        )
 
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=64)
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        torch.cuda.empty_cache()
+        del model
+        gc.collect()
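The other half of the fix is the explicit teardown appended to each test, so one test's 8B weights are released before the next test loads the model again. Below is a self-contained sketch of that pattern, not the test code itself; a small `nn.Linear` stands in for the real model, and the ordering here frees the tensors before returning cached blocks.

    import gc

    import torch
    import torch.nn as nn

    # Stand-in allocation for the model a test would load (fp32 weights, ~268 MB).
    model = nn.Linear(8192, 8192, device="cuda")
    print(torch.cuda.memory_allocated())  # ~268 MB

    del model                 # drop the last Python reference to the weights
    gc.collect()              # make sure the parameter tensors are collected now
    torch.cuda.empty_cache()  # return the freed, cached blocks to the CUDA driver

    print(torch.cuda.memory_allocated())  # back to ~0 bytes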