Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 13:20:12 +06:00
enable csm integration cases on xpu, all passed (#38140)
* enable csm test cases on XPU, all passed

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
parent e5a48785d9
commit 0173a99e73
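For context on the one-line change repeated in the hunks below: the commit swaps the @require_torch_gpu decorator for @require_torch_accelerator so the CSM integration tests are no longer skipped on non-CUDA hardware such as Intel XPU. The snippet below is a minimal, hypothetical sketch of what such skip decorators do; the actual implementations live in transformers.testing_utils and handle more backends and configuration, and the _sketch names here are illustrative only.

import unittest

def require_torch_gpu_sketch(test_case):
    # Hypothetical stand-in: the test runs only when a CUDA GPU is available.
    import torch
    return unittest.skipUnless(torch.cuda.is_available(), "test requires CUDA")(test_case)

def require_torch_accelerator_sketch(test_case):
    # Hypothetical stand-in: any non-CPU backend qualifies (CUDA, or Intel XPU when present).
    import torch
    has_accelerator = torch.cuda.is_available() or (
        hasattr(torch, "xpu") and torch.xpu.is_available()
    )
    return unittest.skipUnless(has_accelerator, "test requires a hardware accelerator")(test_case)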
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -430,7 +430,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
         return ds[0]

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_1b_model_integration_generate(self):
         """
         Tests the generated tokens match the ones from the original model implementation.
@@ -474,7 +474,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
         torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_1b_model_integration_generate_no_audio(self):
         """
         Tests the generated tokens match the ones from the original model implementation.
@@ -535,7 +535,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
         torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_1b_model_integration_generate_multiple_audio(self):
         """
         Test the generated tokens match the ones from the original model implementation.
@@ -594,7 +594,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
         torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_1b_model_integration_generate_batched(self):
         """
         Test the generated tokens match the ones from the original model implementation.