`is_ctc` needs to be updated to `self.type == "ctc"`. (#15194)

* `is_ctc` needs to be updated to `self.type == "ctc"`.

* Adding fast test for this functionality.
This commit is contained in:
Nicolas Patry 2022-01-18 12:20:10 +01:00 committed by GitHub
parent 32090c729f
commit dea563c943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 1 deletions

View File

@ -215,7 +215,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate))
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate))
if not self.is_ctc:
if self.type != "ctc":
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)

View File

@ -278,6 +278,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@require_torch
def test_chunking_fast(self):
    """Run the ASR pipeline with chunking enabled on a tiny test checkpoint.

    The pipeline is created with ``chunk_length_s=10.0`` and fed a single
    waveform made of one dataset sample repeated twice. The test checks
    that exactly one ``{"text": <str>}`` result is returned and pins the
    first six characters of the transcription.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model="hf-internal-testing/tiny-random-wav2vec2",
        chunk_length_s=10.0,
    )

    # Sort by id before indexing so that row 40 is taken from a fixed order.
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    sample = dataset.sort("id")[40]

    # Repeat the sample back to back to obtain a longer input waveform.
    repeat_count = 2
    waveform = np.tile(sample["audio"]["array"], repeat_count)

    results = asr([waveform], batch_size=2)

    # One input in, one result out; the text can be any string...
    self.assertEqual(results, [{"text": ANY(str)}])
    # ...but its first six characters are pinned to a known value.
    self.assertEqual(results[0]["text"][:6], "ZBT ZC")
@require_torch
@slow
def test_chunking(self):