mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Enable granite speech 3.3 tests (#37560)
* Enable granite speech 3.3 tests * skip sdpa test for granite speech * Explicitly move model to device * Use granite speech 2b in tests --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
031ef8802c
commit
06c4d05fe6
@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_datasets_available,
|
||||
is_peft_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
@ -306,11 +307,17 @@ class GraniteSpeechForConditionalGenerationModelTest(ModelTesterMixin, Generatio
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
pass
|
||||
|
||||
|
||||
class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# TODO - use the actual model path on HF hub after release.
|
||||
self.model_path = "ibm-granite/granite-speech"
|
||||
self.model_path = "ibm-granite/granite-speech-3.3-2b"
|
||||
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
||||
self.prompt = self._get_prompt(self.processor.tokenizer)
|
||||
|
||||
@ -338,7 +345,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
@slow
|
||||
@pytest.mark.skip("Public models not yet available")
|
||||
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
|
||||
def test_small_model_integration_test_single(self):
|
||||
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
|
||||
input_speech = self._load_datasamples(1)
|
||||
@ -364,9 +371,9 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
@slow
|
||||
@pytest.mark.skip("Public models not yet available")
|
||||
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
|
||||
def test_small_model_integration_test_batch(self):
|
||||
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path)
|
||||
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
|
||||
input_speech = self._load_datasamples(2)
|
||||
prompts = [self.prompt, self.prompt]
|
||||
|
||||
@ -384,7 +391,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter"
|
||||
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilp's manner less interesting than his matter"
|
||||
] # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -33,14 +33,12 @@ if is_torchaudio_available():
|
||||
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
|
||||
|
||||
|
||||
@pytest.skip("Public models not yet available", allow_module_level=True)
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class GraniteSpeechProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
# TODO - use the actual model path on HF hub after release.
|
||||
self.checkpoint = "ibm-granite/granite-speech"
|
||||
self.checkpoint = "ibm-granite/granite-speech-3.3-8b"
|
||||
processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user