[TTA Pipeline] Fix MusicGen test (#26348)

* fix musicgen pipeline test

* fix wav2vec2 doctest

* revert wav2vec2
This commit is contained in:
Sanchit Gandhi 2023-09-22 16:55:54 +01:00 committed by GitHub
parent 368a58e61c
commit 914771cbfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -71,13 +71,13 @@ class TextToAudioPipeline(Pipeline):
if self.sampling_rate is None:
# get sampling_rate from config and generation config
config = self.model.config.to_dict()
config = self.model.config
gen_config = self.model.__dict__.get("generation_config", None)
if gen_config is not None:
config.update(gen_config.to_dict())
for sampling_rate_name in ["sample_rate", "sampling_rate"]:
sampling_rate = config.get(sampling_rate_name, None)
sampling_rate = getattr(config, sampling_rate_name, None)
if sampling_rate is not None:
self.sampling_rate = sampling_rate