Proposal: Remove the weird `inspect` in the ASR pipeline and make WhisperEncoder nicer to use. (#19571)

* Proposal: Remove the weird `inspect` in the ASR pipeline and make WhisperEncoder nicer to use.

It seems that accepting `attention_mask` is an invariant of our models. For seq2seq ASR models, we even had a special comment explaining why it was important to send it.

Using `inspect` is a pretty brittle way to handle this case. My suggestion is to simply accept `attention_mask` as a kwarg and ignore it, with the docstring explaining why it is ignored.
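
For reference, a minimal sketch of the two approaches, lifted from the pipeline diff below (`model`, `encoder`, `inputs`, and `model_inputs` stand in for the real pipeline state):

```python
import inspect


def generate_before(model, encoder, inputs, model_inputs):
    # Probe the encoder's signature at runtime -- the brittle part.
    accepts_attention_mask = "attention_mask" in set(
        inspect.signature(encoder.forward).parameters.keys()
    )
    if accepts_attention_mask:
        attention_mask = model_inputs.pop("attention_mask", None)
        return model.generate(
            encoder_outputs=encoder(inputs, attention_mask=attention_mask),
            attention_mask=attention_mask,
        )
    return model.generate(inputs)


def generate_after(model, encoder, inputs, model_inputs):
    # Every encoder now accepts `attention_mask` (Whisper simply ignores it),
    # so the pipeline can pass it along unconditionally.
    attention_mask = model_inputs.pop("attention_mask", None)
    return model.generate(
        encoder_outputs=encoder(inputs, attention_mask=attention_mask),
        attention_mask=attention_mask,
    )
```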

* Fixup.

* Update src/transformers/models/whisper/modeling_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Doc fixes.

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Nicolas Patry 2022-11-14 09:34:30 +01:00 committed by GitHub
parent 2308f3d42c
commit 03bc6ece1b
2 changed files with 13 additions and 15 deletions

src/transformers/models/whisper/modeling_whisper.py

@@ -610,6 +610,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
@@ -624,12 +625,14 @@ class WhisperEncoder(WhisperPreTrainedModel):
                `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features,
                padding and conversion into a tensor of type `torch.FloatTensor`. See
                [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`, *optional*):
                Whisper does not support masking of the `input_features`; this argument is preserved for
                compatibility but is not used. By default, silence in the input log-mel spectrogram is ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
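
As a quick illustration of the new contract, a minimal sketch (`openai/whisper-tiny` is used here for illustration; any Whisper checkpoint should behave the same, since the mask is accepted but has no effect on the output):

```python
import numpy as np
import torch
from transformers import WhisperFeatureExtractor, WhisperModel

extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
model = WhisperModel.from_pretrained("openai/whisper-tiny")

# One second of silence at Whisper's expected 16 kHz sampling rate.
audio = np.zeros(16000, dtype=np.float32)
features = extractor(audio, sampling_rate=16000, return_tensors="pt").input_features

encoder = model.get_encoder()
with torch.no_grad():
    out_plain = encoder(features).last_hidden_state
    # The mask is accepted for compatibility but deliberately ignored.
    out_masked = encoder(features, attention_mask=torch.ones_like(features)).last_hidden_state

assert torch.allclose(out_plain, out_masked)
```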

src/transformers/pipelines/automatic_speech_recognition.py

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union
@@ -289,10 +288,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        is_last = model_inputs.pop("is_last")
        if self.type == "seq2seq":
            encoder = self.model.get_encoder()
            # we need to pass `processed.get("attention_mask")` here since the audio encoder
            # attention mask length is different from the expected text decoder `encoder_attention_mask` length
            # `generate` magic to create the mask automatically won't work; we basically need to help
            # it here.
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
@@ -305,15 +300,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )
            accepts_attention_mask = "attention_mask" in set(inspect.signature(encoder.forward).parameters.keys())
            if accepts_attention_mask:
                attention_mask = model_inputs.pop("attention_mask", None)
                tokens = self.model.generate(
                    encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                    attention_mask=attention_mask,
                )
            else:
                tokens = self.model.generate(inputs)
            # we need to pass `processed.get("attention_mask")` here since the audio encoder
            # attention mask length is different from the expected text decoder `encoder_attention_mask` length
            # `generate` magic to create the mask automatically won't work; we basically need to help
            # it here.
            attention_mask = model_inputs.pop("attention_mask", None)
            tokens = self.model.generate(
                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                attention_mask=attention_mask,
            )
            out = {"tokens": tokens}
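
End-to-end pipeline usage is unchanged; a hypothetical invocation (`sample.wav` is a placeholder for a local audio file):

```python
from transformers import pipeline

# The seq2seq branch above now forwards `attention_mask` unconditionally;
# no signature inspection happens under the hood.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
print(asr("sample.wav"))  # e.g. {'text': '...'}
```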