mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix FastSpeech2ConformerModelTest
and skip it on CPU (#28888)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
5346db1684
commit
6529a5b5c1
@ -1256,7 +1256,7 @@ class FastSpeech2ConformerModel(FastSpeech2ConformerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_ids.shape)
|
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
||||||
|
|
||||||
has_missing_labels = (
|
has_missing_labels = (
|
||||||
spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
|
spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
|
||||||
|
@ -25,7 +25,7 @@ from transformers import (
|
|||||||
FastSpeech2ConformerWithHifiGanConfig,
|
FastSpeech2ConformerWithHifiGanConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_g2p_en, require_torch, slow, torch_device
|
from transformers.testing_utils import require_g2p_en, require_torch, require_torch_accelerator, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||||
@ -117,6 +117,7 @@ class FastSpeech2ConformerModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
@require_torch
|
@require_torch
|
||||||
class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
|
class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()
|
all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()
|
||||||
|
Loading…
Reference in New Issue
Block a user