Pipeline: fix unnecessary warnings (#35753)

* return attention mask

* use correct model input name

* fix

* make
This commit is contained in:
eustlb 2025-06-27 14:32:03 +02:00 committed by GitHub
parent 1750c518dd
commit 9c8d3a70b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -64,7 +64,12 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
for chunk_start_idx in range(0, inputs_len, step):
chunk_end_idx = chunk_start_idx + chunk_len
chunk = inputs[chunk_start_idx:chunk_end_idx]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
processed = feature_extractor(
chunk,
sampling_rate=feature_extractor.sampling_rate,
return_tensors="pt",
return_attention_mask=True,
)
if dtype is not None:
processed = processed.to(dtype=dtype)
_stride_left = 0 if chunk_start_idx == 0 else stride_left
@ -507,11 +512,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask,
main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
generate_kwargs = {
main_input_name: inputs,
"attention_mask": attention_mask,
**generate_kwargs,
)
}
tokens = self.model.generate(**generate_kwargs)
# whisper longform generation stores timestamps in "segments"
if return_timestamps == "word" and self.type == "seq2seq_whisper":
if "segments" not in tokens: