Fix SpeechT5ForSpeechToSpeechIntegrationTests device issue (#21460)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-02-06 10:43:07 +01:00 committed by GitHub
parent 59d5edef34
commit 0db5d911fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -2869,7 +2869,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
"""
if speaker_embeddings is None:
speaker_embeddings = torch.zeros((1, 512))
speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
return _generate_speech(
self,

View File

@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1)
input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)
speaker_embeddings = torch.zeros((1, 512))
speaker_embeddings = torch.zeros((1, 512), device=torch_device)
generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)