Mirror of https://github.com/huggingface/transformers.git
Fix Bug in Flax Seq2Seq Models (#16021)

* Fix Bug in Flax Seq2Seq Models
* incorporate suggested changes

Parent: b7018abf3c
Commit: 741e49305d
@@ -104,9 +104,9 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 
             [What are decoder input IDs?](../glossary#decoder-input-ids)
 
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
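The updated docstring tells callers to build `decoder_input_ids` themselves by shifting `labels`. Below is a minimal sketch of such a helper in JAX, modeled on the `shift_tokens_right` functions the library's Flax seq2seq models already use; treat it as illustrative rather than the exact function these classes ship with:

```python
import jax.numpy as jnp

def shift_tokens_right(labels: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Build `decoder_input_ids` by shifting `labels` one position to the right."""
    shifted = jnp.zeros_like(labels)
    shifted = shifted.at[:, 1:].set(labels[:, :-1])          # shift everything right by one
    shifted = shifted.at[:, 0].set(decoder_start_token_id)   # prepend the decoder start token
    # labels mark ignored positions with -100; replace those with the pad token id
    return jnp.where(shifted == -100, pad_token_id, shifted)
```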
@@ -169,9 +169,9 @@ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
 
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
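The decode-path docstring above describes feeding precomputed `encoder_outputs` into `decode`. A hedged sketch of that two-step flow follows; the checkpoint pairing is hypothetical and chosen only for illustration:

```python
import numpy as np
from transformers import AutoTokenizer, FlaxEncoderDecoderModel

# hypothetical encoder/decoder pairing, for illustration only
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("Some input text.", return_tensors="np")

# run the encoder once, then drive the decoder from its cached outputs
encoder_outputs = model.encode(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
decoder_start = np.full((1, 1), model.config.decoder.bos_token_id, dtype="i4")
decoder_outputs = model.decode(decoder_start, encoder_outputs)
```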
@@ -670,6 +670,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
             batch_size, sequence_length = input_ids.shape
             position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
 
+        # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
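With this guard in place, the model fails fast instead of silently shifting `input_ids`. A sketch of a training-style call under the new contract, reusing the `shift_tokens_right` helper and the hypothetical `model`/`tokenizer` from the snippets above (it assumes the config defines `decoder_start_token_id`; the pad-token fallback is likewise an assumption):

```python
labels = tokenizer("A target sentence.", return_tensors="np").input_ids

# model(input_ids=inputs.input_ids) would now raise:
#   ValueError: `decoder_input_ids` cannot be `None`. ...
decoder_input_ids = shift_tokens_right(
    labels,
    pad_token_id=model.config.decoder.pad_token_id or 0,  # assumption: pad id is set, else fall back to 0
    decoder_start_token_id=model.config.decoder_start_token_id,
)
outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_input_ids,
)
```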
@@ -108,8 +108,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
 
-            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
-            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
@@ -161,9 +162,9 @@ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
 
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
@@ -681,6 +682,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             attention_mask = jnp.ones_like(inputs)
 
         # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
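The speech model gains the same guard. Here the old fallback of shifting `input_ids` was never meaningful, since the encoder consumes float speech features rather than token ids, so the call now fails fast. A sketch of the new behavior, again with a hypothetical checkpoint pairing:

```python
import numpy as np
from transformers import FlaxSpeechEncoderDecoderModel

# hypothetical encoder/decoder pairing, for illustration only
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "gpt2"
)

try:
    # one second of silence at 16 kHz, with no decoder inputs supplied
    model(inputs=np.zeros((1, 16000), dtype=np.float32))
except ValueError as err:
    print(err)  # `decoder_input_ids` cannot be `None`. ...
```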