mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix SpeechT5ForSpeechToSpeechIntegrationTests
device issue (#21460)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
59d5edef34
commit
0db5d911fc
@ -2869,7 +2869,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
|
||||
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
|
||||
"""
|
||||
if speaker_embeddings is None:
|
||||
speaker_embeddings = torch.zeros((1, 512))
|
||||
speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
|
||||
|
||||
return _generate_speech(
|
||||
self,
|
||||
|
@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
|
||||
speaker_embeddings = torch.zeros((1, 512))
|
||||
speaker_embeddings = torch.zeros((1, 512), device=torch_device)
|
||||
generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)
|
||||
|
||||
self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)
|
||||
|
Loading…
Reference in New Issue
Block a user