Enable Granite Speech 3.3 tests (#37560)

* Enable Granite Speech 3.3 tests

* Skip SDPA test for Granite Speech

* Explicitly move model to device

* Use Granite Speech 2b in tests

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Alex Brooks 2025-05-06 09:56:18 -06:00 committed by GitHub
parent 031ef8802c
commit 06c4d05fe6
2 changed files with 14 additions and 9 deletions


@@ -33,6 +33,7 @@ from transformers.testing_utils import (
 )
 from transformers.utils import (
     is_datasets_available,
+    is_peft_available,
     is_torch_available,
 )
@@ -306,11 +307,17 @@ class GraniteSpeechForConditionalGenerationModelTest(ModelTesterMixin, Generatio
             if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                 raise ValueError("The eager model should not have SDPA attention layers")
 
+    @pytest.mark.generate
+    @require_torch_sdpa
+    @slow
+    @unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
+    def test_eager_matches_sdpa_generate(self):
+        pass
+
 
 class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
     def setUp(self):
-        # TODO - use the actual model path on HF hub after release.
-        self.model_path = "ibm-granite/granite-speech"
+        self.model_path = "ibm-granite/granite-speech-3.3-2b"
         self.processor = AutoProcessor.from_pretrained(self.model_path)
         self.prompt = self._get_prompt(self.processor.tokenizer)
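For context, opting a single model out of an inherited test like `test_eager_matches_sdpa_generate` works by overriding it in the model's test class and decorating the override; the shadowed mixin version then never runs. A minimal, self-contained sketch of the pattern (the mixin and test bodies here are illustrative stand-ins, not the real `GenerationTesterMixin`):

```python
import unittest


class SharedTestsMixin:
    # Stand-in for a shared mixin such as GenerationTesterMixin: every model
    # test class inherits this test unless it overrides it.
    def test_eager_matches_sdpa_generate(self):
        raise AssertionError("would exercise SDPA, which this model cannot use")


class ExampleModelTest(SharedTestsMixin, unittest.TestCase):
    # The override shadows the mixin's method, and the decorator turns it
    # into a reported skip instead of a failure.
    @unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
    def test_eager_matches_sdpa_generate(self):
        pass


if __name__ == "__main__":
    unittest.main()
```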
@@ -338,7 +345,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
         return [x["array"] for x in speech_samples]
 
     @slow
-    @pytest.mark.skip("Public models not yet available")
+    @pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
    def test_small_model_integration_test_single(self):
         model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
         input_speech = self._load_datasamples(1)
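The `skipif` marker above evaluates its condition once at collection time, so the peft availability check never runs inside the test body. A minimal sketch of the same gating outside a test class (`is_peft_available` is the real helper from `transformers.utils`; the test name is a hypothetical example):

```python
import pytest

from transformers.utils import is_peft_available


# The condition is evaluated when the module is collected, so machines
# without peft installed report a skip with the given reason instead of
# running the test and diverging from the LoRA-adapted expected outputs.
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
def test_transcription_matches_expected_output():
    ...
```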
@@ -364,9 +371,9 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
     @slow
-    @pytest.mark.skip("Public models not yet available")
+    @pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
     def test_small_model_integration_test_batch(self):
-        model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path)
+        model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
         input_speech = self._load_datasamples(2)
         prompts = [self.prompt, self.prompt]
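On the `.to(torch_device)` change: `from_pretrained` materializes weights on CPU by default, so on a GPU runner the model would otherwise sit on a different device than the tensors the test places on `torch_device`, and the forward pass would fail with a device mismatch. A sketch of the fix, assuming the checkpoint named in this diff:

```python
from transformers import GraniteSpeechForConditionalGeneration
from transformers.testing_utils import torch_device

# Load on CPU (the from_pretrained default), then move the whole model to
# the test device so it matches inputs placed with .to(torch_device).
model = GraniteSpeechForConditionalGeneration.from_pretrained(
    "ibm-granite/granite-speech-3.3-2b"
).to(torch_device)
```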
@@ -384,7 +391,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
         EXPECTED_DECODED_TEXT = [
             "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
-            "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter"
+            "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilp's manner less interesting than his matter"
         ]  # fmt: skip
         self.assertEqual(


@@ -33,14 +33,12 @@ if is_torchaudio_available():
     from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
 
 
-@pytest.skip("Public models not yet available", allow_module_level=True)
 @require_torch
 @require_torchaudio
 class GraniteSpeechProcessorTest(unittest.TestCase):
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
-        # TODO - use the actual model path on HF hub after release.
-        self.checkpoint = "ibm-granite/granite-speech"
+        self.checkpoint = "ibm-granite/granite-speech-3.3-8b"
         processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
         processor.save_pretrained(self.tmpdirname)
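The `setUp` above follows the usual processor-test round trip: fetch once from the Hub, serialize to a temp dir, and let the tests reload from disk. A minimal sketch of that flow (using a context-managed temp dir rather than the class's `mkdtemp`/teardown pair):

```python
import tempfile

from transformers import GraniteSpeechProcessor

# Download the processor once, write it to disk, and reload it, which is
# what the test class does across setUp and its test methods.
with tempfile.TemporaryDirectory() as tmpdir:
    processor = GraniteSpeechProcessor.from_pretrained("ibm-granite/granite-speech-3.3-8b")
    processor.save_pretrained(tmpdir)
    reloaded = GraniteSpeechProcessor.from_pretrained(tmpdir)
```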