Fix Bug in Flax Seq2Seq Models (#16021)

* Fix Bug in Flax Seq2Seq Models

* incorporate suggested changes
Sanchit Gandhi 2022-03-10 14:58:05 +01:00 committed by GitHub
parent b7018abf3c
commit 741e49305d
2 changed files with 21 additions and 11 deletions

src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py

@@ -104,9 +104,9 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             [What are decoder input IDs?](../glossary#decoder-input-ids)
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
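The new wording tells callers to build `decoder_input_ids` themselves. For reference, a minimal NumPy sketch of that shift-right recipe, in the spirit of the `shift_tokens_right` helpers that ship with the library's Flax seq2seq models (the standalone function here is illustrative, not the library's exact implementation):

```python
import numpy as np

def shift_tokens_right(labels: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """Shift `labels` one position to the right and prepend `decoder_start_token_id`."""
    shifted = np.zeros_like(labels)
    shifted[:, 1:] = labels[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    # labels mark ignored positions with -100; the decoder must see real token ids,
    # so replace them with the padding token
    return np.where(shifted == -100, pad_token_id, shifted)
```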
@@ -169,9 +169,9 @@ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
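Since this docstring belongs to `decode`, a brief sketch of the encode/decode split it documents may help; the checkpoint pairing is a placeholder, and it assumes the decoder config defines `bos_token_id`:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxEncoderDecoderModel

# placeholder checkpoint pairing, for illustration only
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

input_ids = tokenizer("A short input sentence.", return_tensors="np").input_ids

# run the encoder once, then hand its outputs to decode()
encoder_outputs = model.encode(input_ids)
decoder_input_ids = jnp.array([[model.config.decoder.bos_token_id]])  # assumes bos_token_id is set
outputs = model.decode(decoder_input_ids, encoder_outputs)
logits = outputs.logits
```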
@@ -670,6 +670,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
             batch_size, sequence_length = input_ids.shape
             position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
         # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+                " be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
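With this guard in place, a training-style forward pass has to supply `decoder_input_ids` explicitly, e.g. with the shift sketched above. A hedged sketch (placeholder checkpoints; assumes `pad_token_id` and `decoder_start_token_id` are set on the config):

```python
from transformers import AutoTokenizer, FlaxEncoderDecoderModel

# placeholder checkpoint pairing, for illustration only
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

input_ids = tokenizer("A short input sentence.", return_tensors="np").input_ids
labels = tokenizer("A short target sentence.", return_tensors="np").input_ids

# the caller now builds decoder_input_ids; the model no longer does it implicitly
# (uses the shift_tokens_right sketch from above)
decoder_input_ids = shift_tokens_right(
    labels, model.config.pad_token_id, model.config.decoder_start_token_id
)

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
```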

src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py

@@ -108,8 +108,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
-            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
@@ -161,9 +162,9 @@ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
@@ -681,6 +682,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             attention_mask = jnp.ones_like(inputs)

         # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
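The same guard in the speech model turns a missing `decoder_input_ids` into an immediate, descriptive error. A small sketch of the new behavior (placeholder checkpoints; dummy audio input; assumes the decoder config sets `bos_token_id`):

```python
import numpy as np
from transformers import FlaxSpeechEncoderDecoderModel

# placeholder checkpoint pairing, for illustration only
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "gpt2"
)

inputs = np.zeros((1, 16000), dtype=np.float32)  # one second of silent 16 kHz audio

# omitting decoder_input_ids now fails fast with a clear message
try:
    model(inputs)
except ValueError as err:
    print(err)

# an explicit decoder prompt works as before
outputs = model(inputs, decoder_input_ids=np.array([[model.config.decoder.bos_token_id]]))
```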