Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
fix UT failures on XPU w/ stock PyTorch 2.7 & 2.8 (#39116)
* fix UT failures on XPU w/ stock PyTorch 2.7 & 2.8
* zamba2
* xx
* internvl
* tp cases

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
parent ccf2ca162e
commit 2100ee6545
@@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     is_flash_attn_2_available,
     require_flash_attn,
     require_read_token,
@@ -136,6 +137,9 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
 class Cohere2IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]

+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
     def test_model_bf16(self):
         model_id = "CohereForAI/c4ai-command-r7b-12-2024"
         EXPECTED_TEXTS = [
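Note on the tearDown added above: cleanup(torch_device, gc_collect=True) releases cached accelerator memory between integration tests so back-to-back model loads do not exhaust device memory. A rough, backend-aware sketch of the idea, assuming only CUDA and XPU backends (the real helper in transformers.testing_utils covers more cases):

    import gc

    import torch


    def cleanup_sketch(device: str, gc_collect: bool = False) -> None:
        """Free cached accelerator memory; optionally force a GC pass first."""
        if gc_collect:
            gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        elif device == "xpu":
            # Stock PyTorch >= 2.4 ships torch.xpu with an API mirroring
            # torch.cuda, which is what lets these tests run on Intel GPUs.
            torch.xpu.empty_cache()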
@@ -29,6 +29,7 @@ from transformers import (
 )
 from transformers.file_utils import cached_property
 from transformers.testing_utils import (
+    Expectations,
     is_flaky,
     require_timm,
     require_torch,
@@ -804,34 +805,62 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
         with torch.no_grad():
             outputs = model(**text_inputs, **image_inputs)

-        # Loss differs by CPU and GPU, also this can be changed in future.
-        expected_loss_dict = {
-            "loss_ce": torch.tensor(1.1147),
-            "loss_bbox": torch.tensor(0.2031),
-            "loss_giou": torch.tensor(0.5819),
-            "loss_ce_0": torch.tensor(1.1941),
-            "loss_bbox_0": torch.tensor(0.1978),
-            "loss_giou_0": torch.tensor(0.5524),
-            "loss_ce_1": torch.tensor(1.1621),
-            "loss_bbox_1": torch.tensor(0.1909),
-            "loss_giou_1": torch.tensor(0.5892),
-            "loss_ce_2": torch.tensor(1.1641),
-            "loss_bbox_2": torch.tensor(0.1892),
-            "loss_giou_2": torch.tensor(0.5626),
-            "loss_ce_3": torch.tensor(1.1943),
-            "loss_bbox_3": torch.tensor(0.1941),
-            "loss_giou_3": torch.tensor(0.5607),
-            "loss_ce_4": torch.tensor(1.0956),
-            "loss_bbox_4": torch.tensor(0.2008),
-            "loss_giou_4": torch.tensor(0.5836),
-            "loss_ce_enc": torch.tensor(16226.3164),
-            "loss_bbox_enc": torch.tensor(0.3063),
-            "loss_giou_enc": torch.tensor(0.7380),
-        }
+        # Loss differs by CPU and accelerator, also this can be changed in future.
+        expected_loss_dicts = Expectations(
+            {
+                ("xpu", 3): {
+                    "loss_ce": torch.tensor(1.1147),
+                    "loss_bbox": torch.tensor(0.2031),
+                    "loss_giou": torch.tensor(0.5819),
+                    "loss_ce_0": torch.tensor(1.1941),
+                    "loss_bbox_0": torch.tensor(0.1978),
+                    "loss_giou_0": torch.tensor(0.5524),
+                    "loss_ce_1": torch.tensor(1.1621),
+                    "loss_bbox_1": torch.tensor(0.1909),
+                    "loss_giou_1": torch.tensor(0.5892),
+                    "loss_ce_2": torch.tensor(1.1641),
+                    "loss_bbox_2": torch.tensor(0.1892),
+                    "loss_giou_2": torch.tensor(0.5626),
+                    "loss_ce_3": torch.tensor(1.1943),
+                    "loss_bbox_3": torch.tensor(0.1941),
+                    "loss_giou_3": torch.tensor(0.5592),
+                    "loss_ce_4": torch.tensor(1.0956),
+                    "loss_bbox_4": torch.tensor(0.2037),
+                    "loss_giou_4": torch.tensor(0.5813),
+                    "loss_ce_enc": torch.tensor(16226.3164),
+                    "loss_bbox_enc": torch.tensor(0.3063),
+                    "loss_giou_enc": torch.tensor(0.7380),
+                },
+                ("cuda", None): {
+                    "loss_ce": torch.tensor(1.1147),
+                    "loss_bbox": torch.tensor(0.2031),
+                    "loss_giou": torch.tensor(0.5819),
+                    "loss_ce_0": torch.tensor(1.1941),
+                    "loss_bbox_0": torch.tensor(0.1978),
+                    "loss_giou_0": torch.tensor(0.5524),
+                    "loss_ce_1": torch.tensor(1.1621),
+                    "loss_bbox_1": torch.tensor(0.1909),
+                    "loss_giou_1": torch.tensor(0.5892),
+                    "loss_ce_2": torch.tensor(1.1641),
+                    "loss_bbox_2": torch.tensor(0.1892),
+                    "loss_giou_2": torch.tensor(0.5626),
+                    "loss_ce_3": torch.tensor(1.1943),
+                    "loss_bbox_3": torch.tensor(0.1941),
+                    "loss_giou_3": torch.tensor(0.5607),
+                    "loss_ce_4": torch.tensor(1.0956),
+                    "loss_bbox_4": torch.tensor(0.2008),
+                    "loss_giou_4": torch.tensor(0.5836),
+                    "loss_ce_enc": torch.tensor(16226.3164),
+                    "loss_bbox_enc": torch.tensor(0.3063),
+                    "loss_giou_enc": torch.tensor(0.7380),
+                },
+            }
+        ) # fmt: skip
+        expected_loss_dict = expected_loss_dicts.get_expectation()

         expected_loss = torch.tensor(32482.2305)

         for key in expected_loss_dict:
-            self.assertTrue(torch.allclose(outputs.loss_dict[key], expected_loss_dict[key], atol=1e-3))
+            torch.testing.assert_close(outputs.loss_dict[key], expected_loss_dict[key], rtol=1e-5, atol=1e-3)

         self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-3))
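Note on the mechanism this hunk (and the rest of the commit) relies on: Expectations maps (device_type, major_version) keys to per-backend reference values, and get_expectation() resolves the entry matching the machine the test runs on, with ("cuda", None) acting as a version-agnostic default. An illustrative sketch of the behaviour, not the library's actual implementation (the real helper infers the current device itself):

    from typing import Any, Dict, Optional, Tuple


    class ExpectationsSketch:
        """Map (device_type, major_version) keys to expected test outputs."""

        def __init__(self, data: Dict[Tuple[str, Optional[int]], Any]) -> None:
            self.data = data

        def get_expectation(self, device_type: str, major: Optional[int] = None) -> Any:
            # Prefer an exact (device, version) match, then a version-agnostic
            # entry for the device, then the ("cuda", None) default.
            for key in ((device_type, major), (device_type, None), ("cuda", None)):
                if key in self.data:
                    return self.data[key]
            raise KeyError(f"no expectation registered for {device_type!r}")


    # Values echo loss_giou_3 from the hunk above.
    expectations = ExpectationsSketch({("xpu", 3): 0.5592, ("cuda", None): 0.5607})
    print(expectations.get_expectation("xpu", 3))   # 0.5592
    print(expectations.get_expectation("cuda", 8))  # falls back to 0.5607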
@@ -30,6 +30,8 @@ from transformers import (
     InstructBlipVisionConfig,
 )
 from transformers.testing_utils import (
+    Expectations,
+    cleanup,
     require_accelerate,
     require_bitsandbytes,
     require_torch,
@@ -722,6 +724,9 @@ def prepare_img():
 @require_torch
 @slow
 class InstructBlipModelIntegrationTest(unittest.TestCase):
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=False)
+
     @require_bitsandbytes
     @require_accelerate
     def test_inference_vicuna_7b(self):
@@ -739,13 +744,24 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
         outputs = model.generate(**inputs, max_new_tokens=30)
         generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

-        expected_outputs = [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 373, 263, 19587, 4272, 11952, 29889] # fmt: off
+        expected_outputs = Expectations(
+            {
+                ("xpu", 3): [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 1623, 263, 19587, 4272, 11952, 29889],
+                ("cuda", None): [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 373, 263, 19587, 4272, 11952, 29889],
+            }
+        ) # fmt: off
+        expected_output = expected_outputs.get_expectation()

-        self.assertEqual(outputs[0].tolist(), expected_outputs)
-        self.assertEqual(
-            generated_text,
-            "What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving on a busy city street.",
-        )
+        expected_texts = Expectations(
+            {
+                ("xpu", 3): "What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving down a busy city street.",
+                ("cuda", None): "What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving on a busy city street.",
+            }
+        ) # fmt: off
+        expected_text = expected_texts.get_expectation()
+
+        self.assertEqual(outputs[0].tolist(), expected_output)
+        self.assertEqual(generated_text, expected_text)

     def test_inference_flant5_xl(self):
         processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
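Side note on the assertion pattern above: checking the raw token ids and the decoded string separately localizes failures; here the two backends' id lists differ at exactly one position (1623 vs 373), which lines up with the "down" vs "on" wording difference in the decoded texts. A hedged restatement of the pattern in isolation, reusing the names from the hunk:

    # Names (expected_outputs, expected_texts, outputs, generated_text) come
    # from the test above; only the resolve-then-assert shape is shown here.
    expected_output = expected_outputs.get_expectation()
    expected_text = expected_texts.get_expectation()

    self.assertEqual(outputs[0].tolist(), expected_output)  # exact id-level check
    self.assertEqual(generated_text, expected_text)         # human-readable check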
@@ -430,7 +430,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):

         expected_outputs = Expectations(
             {
-                ("xpu", 3): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate"',
+                ("xpu", 3): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of',
                 ("cuda", 7): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of',
             }
         ) # fmt: skip
@@ -793,7 +793,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
         decoded_output = processor.decode(output[0], skip_special_tokens=True)
         expected_outputs = Expectations(
             {
-                ("xpu", 3): "user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden path leads to calm lake,\nNature's peaceful grace.",
+                ("xpu", 3): "user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.",
                 ("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
                 ("cuda", 8): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
             }
@@ -17,6 +17,8 @@ import unittest

 from transformers import is_torch_available
 from transformers.testing_utils import (
+    Expectations,
+    cleanup,
     require_read_token,
     require_torch_large_accelerator,
     slow,
@@ -78,10 +80,17 @@ class Llama4IntegrationTest(unittest.TestCase):
         },
     ]

+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
     def test_model_17b_16e_fp16(self):
-        EXPECTED_TEXT = [
-            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
-        ] # fmt: skip
+        EXPECTED_TEXTS = Expectations(
+            {
+                ("xpu", 3): ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach with a blue sky and a body of water in the background. The cow is brown with a white face'],
+                ("cuda", None): ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'],
+            }
+        ) # fmt: skip
+        EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()

         inputs = self.processor.apply_chat_template(
             self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
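For readers unfamiliar with the call at the bottom of the hunk: apply_chat_template with tokenize=True and return_dict=True returns a tensor dict that feeds straight into generate. A hedged sketch of what self.messages_1 might look like; its actual definition lives earlier in the test class, and the URL here is a placeholder:

    # Assumed message structure; the real self.messages_1 is not shown in this hunk.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "url": "https://example.com/cow_on_beach.jpg"},  # placeholder
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    output = model.generate(**inputs, max_new_tokens=30)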
@@ -22,6 +22,7 @@ from parameterized import parameterized

 from transformers import AutoTokenizer, Zamba2Config, is_torch_available
 from transformers.testing_utils import (
+    Expectations,
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
@@ -678,14 +679,23 @@ class Zamba2ModelIntegrationTest(unittest.TestCase):
             ]
             , dtype=torch.float32) # fmt: skip

-        EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
-            [
-                0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
-                -6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
-                -6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
-                -6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
-                -6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105 ]
-            , dtype=torch.float32) # fmt: skip
+        EXPECTED_LOGITS_NO_GRAD_1S = Expectations(
+            {
+                ("xpu", 3): torch.tensor([0.2027, 6.3481, 3.8392, -5.7279, -6.5090, -6.5088, -6.5087, -6.5088,
+                                          -6.5087, -6.5088, -6.5090, -6.5089, 7.8796, 13.5483, -6.5088, -6.5080,
+                                          -6.5090, -6.5086, -6.5090, -6.5090, -6.5089, -6.5090, -6.5088, -6.5090,
+                                          -6.5089, -6.5090, -6.5090, -6.5097, -6.5086, -6.5089, -6.5092, -6.5089,
+                                          -6.5088, -6.5090, -6.5090, -6.5088, -6.5090, -6.5091, -6.5087, -6.5089],
+                                         dtype=torch.float32),
+                ("cuda", None): torch.tensor([0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
+                                              -6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
+                                              -6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
+                                              -6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
+                                              -6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105],
+                                             dtype=torch.float32),
+            }
+        ) # fmt: skip
+        EXPECTED_LOGITS_NO_GRAD_1 = EXPECTED_LOGITS_NO_GRAD_1S.get_expectation()

         torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
         torch.testing.assert_close(
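The XPU logits here differ from CUDA only in the third or fourth decimal place, which is why the comparison uses torch.testing.assert_close with explicit tolerances: unlike assertTrue(torch.allclose(...)), which fails with an opaque "False is not true", it reports which elements diverge and by how much. A minimal self-contained demonstration, with values borrowed from the hunk:

    import torch

    cuda_ref = torch.tensor([0.1966, 6.3449, 3.8350])
    xpu_vals = torch.tensor([0.2027, 6.3481, 3.8392])

    try:
        torch.testing.assert_close(xpu_vals, cuda_ref, rtol=1e-3, atol=1e-3)
    except AssertionError as err:
        # The message pinpoints the mismatched element count and the greatest
        # absolute/relative differences, which makes triaging backend drift easy.
        print(err)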
@@ -520,14 +520,14 @@ class Pipeline4BitTest(Base4bitTest):

 @require_torch_multi_accelerator
 @apply_skip_if_not_implemented
-class Bnb4bitTestMultiGpu(Base4bitTest):
+class Bnb4bitTestMultiAccelerator(Base4bitTest):
     def setUp(self):
         super().setUp()

-    def test_multi_gpu_loading(self):
+    def test_multi_accelerator_loading(self):
         r"""
-        This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
-        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
+        This tests that the model has been loaded and can be used correctly on a multi-accelerator setup.
+        Let's just try to load a model on 2 accelerators and see if it works. The model we test has ~2GB of total, 3GB should suffice
         """
         device_map = {
             "transformer.word_embeddings": 0,
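The Gpu-to-Accelerator rename works because, under accelerate, the integer indices in a device_map refer to whichever backend is active, CUDA or XPU alike. A hedged sketch of loading a 4-bit model across all visible accelerators; the checkpoint name is an assumption, and the real test pins each module to an index by name rather than using "auto":

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # device_map="auto" lets accelerate balance the quantized weights across
    # all visible accelerators; an explicit {module_name: index} dict, as in
    # the test above, pins the split instead.
    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m",  # assumption: the Base4bitTest checkpoint is defined elsewhere
        quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
        device_map="auto",
    )
    print(model.hf_device_map)  # reports which module landed on which device index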
@@ -24,7 +24,7 @@ from transformers.testing_utils import (
     backend_device_count,
     get_torch_dist_unique_port,
     require_huggingface_hub_greater_or_equal,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     torch_device,
 )

@@ -168,6 +168,6 @@ class TestTensorParallel(TestCasePlus):
         del non_tp_tensor, tp_tensor


-@require_torch_multi_gpu
-class TestTensorParallelCuda(TestTensorParallel):
+@require_torch_multi_accelerator
+class TestTensorParallelAccelerator(TestTensorParallel):
     nproc_per_node = backend_device_count(torch_device)
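Same motivation behind this rename: backend_device_count(torch_device) counts devices for whatever backend torch_device resolves to, so the test class no longer assumes CUDA. A small usage sketch:

    from transformers.testing_utils import backend_device_count, torch_device

    # On a CUDA box this resolves to torch.cuda.device_count(); on an XPU box,
    # to the XPU equivalent. The test sizes its torch.distributed launch with it.
    nproc_per_node = backend_device_count(torch_device)
    print(f"running tensor-parallel tests across {nproc_per_node} {torch_device} device(s)")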