mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
[Flax] Correct shift labels for seq2seq models in Flax (#12720)
* fix_torch_device_generate_test
* remove @
* push
* fix marian
* fix
* up
This commit is contained in:
parent 1a3deae820
commit 8244c5ad4f
@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Callable, Optional, Tuple

+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -212,15 +214,15 @@ BART_DECODE_INPUTS_DOCSTRING = r"""
 """


-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids

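As a quick sanity check outside the diff, here is a minimal standalone sketch of what the new NumPy-based helper computes on a toy batch of labels; the token ids, pad id and decoder start id below are made up for illustration.

import numpy as np

# Hypothetical labels for one sequence; -100 marks positions ignored by the loss.
labels = np.array([[145, 29, 2, -100]])
pad_token_id = 1
decoder_start_token_id = 2

shifted = np.zeros_like(labels)
shifted[:, 1:] = labels[:, :-1]          # shift every label one position to the right
shifted[:, 0] = decoder_start_token_id   # the decoder always starts from its start token
shifted = np.where(shifted == -100, pad_token_id, shifted)  # -100 is only a loss mask, not a real id

print(shifted)  # [[  2 145  29   2]]

In words: the last label is dropped, every other label moves one slot to the right, and the freed first slot is filled with the decoder start token, which is the decoder input that teacher forcing expects.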
@@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids

@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Callable, Optional, Tuple

+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray
     Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
     have a single `decoder_start_token_id` in contrast to other Bart-like models.
     """
-    prev_output_tokens = jnp.array(input_ids).clone()
+    prev_output_tokens = np.array(input_ids).copy()

     assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

     # replace possible -100 values in labels by `pad_token_id`
-    prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
-    index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
-    decoder_start_tokens = jnp.array(
-        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
+    prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
+    index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
+    decoder_start_tokens = np.array(
+        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
     ).squeeze()
-    # for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
-    for i in range(prev_output_tokens.shape[1], 0, -1):
-        prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1])
-    prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens)
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
+    prev_output_tokens[:, 0] = decoder_start_tokens

     return prev_output_tokens

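The MBart variant follows the same NumPy pattern but has no fixed decoder start token: it wraps the last non-pad token (the language id) around to the front. A standalone sketch, again not part of the commit, with invented ids (pad id 1, language id token 250004):

import numpy as np

pad_token_id = 1
# Hypothetical labels: tokens, then </s> (2), then the language id (250004), then padding.
input_ids = np.array([[47, 212, 93, 2, 250004, 1, 1]])

prev_output_tokens = np.array(input_ids).copy()
prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
# index of the last non-pad token in each row (here: the language id at position 4)
index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
decoder_start_tokens = np.array(
    [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
).squeeze()
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
prev_output_tokens[:, 0] = decoder_start_tokens

print(prev_output_tokens)  # [[250004 47 212 93 2 250004 1]]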
@@ -47,15 +47,16 @@ _CONFIG_FOR_DOC = "T5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"


-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids

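A side note on the code being removed: jax.ops.index_update was later deprecated and removed from JAX in favour of the .at[...] indexed-update syntax. If one wanted a purely functional jnp version of the same shift, it could look like the sketch below; this is an assumption about an alternative, not something this commit adds.

import jax.numpy as jnp

def shift_tokens_right_jnp(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    # Roll every token one position to the right; the wrapped-around last token
    # is immediately overwritten with the decoder start token via .at[...].set(...).
    shifted = jnp.roll(input_ids, 1, axis=-1)
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    # Replace possible -100 label values (loss mask) with the pad token id.
    return jnp.where(shifted == -100, pad_token_id, shifted)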