Mirror of https://github.com/huggingface/transformers.git
Pipeline: fix unnecessary warnings (#35753)
* return attention mask
* use correct model input name
* fix
* make
parent 1750c518dd
commit 9c8d3a70b8
@@ -64,7 +64,12 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
     for chunk_start_idx in range(0, inputs_len, step):
         chunk_end_idx = chunk_start_idx + chunk_len
         chunk = inputs[chunk_start_idx:chunk_end_idx]
-        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+        processed = feature_extractor(
+            chunk,
+            sampling_rate=feature_extractor.sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        )
         if dtype is not None:
             processed = processed.to(dtype=dtype)
         _stride_left = 0 if chunk_start_idx == 0 else stride_left
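
A minimal sketch (not part of the commit) of what the expanded feature-extractor call now returns: with return_attention_mask=True the batch carries an explicit attention_mask next to the features, so downstream forward/generate calls no longer have to infer one, in line with the "return attention mask" bullet in the commit message. The default-constructed WhisperFeatureExtractor below is an assumption chosen purely for illustration; any ASR feature extractor with padding support behaves the same way.

import numpy as np
from transformers import WhisperFeatureExtractor

# Default 16 kHz Whisper config; no checkpoint download required.
feature_extractor = WhisperFeatureExtractor()
chunk = np.zeros(16000, dtype=np.float32)  # 1 second of silence

processed = feature_extractor(
    chunk,
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
    return_attention_mask=True,
)

# The mask now travels with the features instead of being guessed later.
print(list(processed.keys()))  # ['input_features', 'attention_mask']
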
@@ -507,11 +512,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if "generation_config" not in generate_kwargs:
                 generate_kwargs["generation_config"] = self.generation_config
 
-            tokens = self.model.generate(
-                inputs=inputs,
-                attention_mask=attention_mask,
+            main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
+            generate_kwargs = {
+                main_input_name: inputs,
+                "attention_mask": attention_mask,
                 **generate_kwargs,
-            )
+            }
+            tokens = self.model.generate(**generate_kwargs)
 
             # whisper longform generation stores timestamps in "segments"
             if return_timestamps == "word" and self.type == "seq2seq_whisper":
                 if "segments" not in tokens:
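
A minimal sketch (not part of the commit) of why the generate kwargs are now keyed on main_input_name: the primary input name differs per architecture (for example "input_features" for Whisper seq2seq models versus "input_values" for Wav2Vec2-style models), so hardcoding inputs= can make generate() warn about re-routed or unused arguments, matching the "use correct model input name" bullet. The "openai/whisper-tiny" checkpoint is an assumption used only to make the sketch concrete.

import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Assumed checkpoint, small enough to download quickly.
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
print(model.main_input_name)  # "input_features" for Whisper

features = processor(
    np.zeros(16000, dtype=np.float32),  # 1 second of silence at 16 kHz
    sampling_rate=16000,
    return_tensors="pt",
    return_attention_mask=True,
)

# Mirror the hunk above: key the tensor under the model's own input name.
main_input_name = model.main_input_name if hasattr(model, "main_input_name") else "inputs"
generate_kwargs = {
    main_input_name: features[main_input_name],
    "attention_mask": features["attention_mask"],
}
tokens = model.generate(**generate_kwargs)
print(processor.batch_decode(tokens, skip_special_tokens=True))
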