[Flax] Correct shift labels for seq2seq models in Flax (#12720)

* fix_torch_device_generate_test

* remove @

* push

* fix marian

* fix

* up
Patrick von Platen 2021-07-15 07:42:36 +01:00 committed by GitHub
parent 1a3deae820
commit 8244c5ad4f
4 changed files with 27 additions and 23 deletions

src/transformers/models/bart/modeling_flax_bart.py

@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Callable, Optional, Tuple
 
+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -212,15 +214,15 @@ BART_DECODE_INPUTS_DOCSTRING = r"""
 """
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids
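
For reference, a minimal sketch (not part of the commit) of what the new NumPy-based helper produces for a small batch of labels; the token ids and special-token values below are made up for illustration:

import numpy as np

def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    # same logic as the new Flax BART helper in the hunk above
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    # -100 (the loss-masking value in labels) must not reach the decoder, so map it to the pad id
    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids

# labels with -100 marking positions ignored by the loss (illustrative values)
labels = np.array([[13, 7, 42, 6, 2], [13, 9, 2, -100, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0))
# [[ 0 13  7 42  6]
#  [ 0 13  9  2  1]]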

src/transformers/models/marian/modeling_flax_marian.py

@@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids

src/transformers/models/mbart/modeling_flax_mbart.py

@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Callable, Optional, Tuple
 
+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray
     Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
     have a single `decoder_start_token_id` in contrast to other Bart-like models.
     """
-    prev_output_tokens = jnp.array(input_ids).clone()
+    prev_output_tokens = np.array(input_ids).copy()
     assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
     # replace possible -100 values in labels by `pad_token_id`
-    prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
-    index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
-    decoder_start_tokens = jnp.array(
-        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
+    prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
+    index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
+    decoder_start_tokens = np.array(
+        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
     ).squeeze()
-    # for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
-    for i in range(prev_output_tokens.shape[1], 0, -1):
-        prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1])
-    prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens)
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
+    prev_output_tokens[:, 0] = decoder_start_tokens
     return prev_output_tokens
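
As a sanity check on the MBart variant, here is a small, hypothetical example (not part of the commit) run against the new NumPy code; 2 stands for </s>, 250004 stands in for a language-id token, and 1 for the pad id:

import numpy as np

def shift_tokens_right_mbart(input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    # mirrors the new NumPy version above: the last non-pad token (the language id that
    # MBart places after </s>) is wrapped around to become the decoder start token
    prev_output_tokens = np.array(input_ids).copy()
    assert pad_token_id is not None, "pad_token_id has to be defined."
    prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
    index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
    decoder_start_tokens = np.array(
        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
    ).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens

labels = np.array([[57, 8, 2, 250004], [96, 2, 250004, 1]])
print(shift_tokens_right_mbart(labels, pad_token_id=1))
# [[250004     57      8      2]
#  [250004     96      2 250004]]
# In the padded second row the shifted language id lands on the former pad slot;
# that position sits past </s> and is expected to be masked for the loss/attention.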

src/transformers/models/t5/modeling_flax_t5.py

@@ -47,15 +47,16 @@ _CONFIG_FOR_DOC = "T5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.zeros_like(input_ids)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids
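
Since all four copies replace the old jnp.roll / jax.ops.index_update logic with host-side NumPy, a quick consistency check (hypothetical, not part of the commit) can compare the two formulations; jax.ops.index_update has since been removed from JAX, so .at[:, 0].set(...) stands in for it below:

import jax.numpy as jnp
import numpy as np

def shift_np(input_ids, pad_token_id, decoder_start_token_id):
    # new NumPy-based version, as introduced by this commit
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    return np.where(shifted == -100, pad_token_id, shifted)

def shift_jnp(input_ids, pad_token_id, decoder_start_token_id):
    # old roll-based logic; .at[:, 0].set(...) replaces the removed jax.ops.index_update
    shifted = jnp.roll(input_ids, 1, axis=-1)
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    return jnp.where(shifted == -100, pad_token_id, shifted)

labels = np.array([[13, 7, 42, 2], [13, 9, 2, -100]])
assert (shift_np(labels, 1, 0) == np.asarray(shift_jnp(labels, 1, 0))).all()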