[Flax T5] Fix weight initialization and fix docs (#12327)

* finish t5 flax fixes

* improve naming
Patrick von Platen 2021-06-23 17:39:21 +01:00 committed by GitHub
parent 12a4457c56
commit 468cda20f2

@@ -84,8 +84,21 @@ class FlaxT5DenseReluDense(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.wi = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
self.wi = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__(self, hidden_states, deterministic=True):
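The pattern above is T5's fan-in-scaled initialization: each kernel is drawn from a normal distribution with standard deviation initializer_factor * fan_in ** -0.5, where fan_in is the layer's input dimension (d_model for wi, d_ff for wo). A minimal standalone sketch of the convention, with toy dimensions, not part of this diff:

import jax
import jax.numpy as jnp

factor, d_model, d_ff = 1.0, 512, 2048
wi_init = jax.nn.initializers.normal(factor * (d_model ** -0.5))  # std scales with 1/sqrt(fan_in)
wo_init = jax.nn.initializers.normal(factor * (d_ff ** -0.5))

key = jax.random.PRNGKey(0)
wi_kernel = wi_init(key, (d_model, d_ff), jnp.float32)
print(wi_kernel.std())  # approx 0.044, i.e. 512 ** -0.5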
@@ -101,9 +114,27 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.wi_0 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
self.wi_1 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
self.wi_0 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
dtype=self.dtype,
)
self.wi_1 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
self.gelu_act = ACT2FN["gelu_new"]
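For context, the gated-GELU feed-forward multiplies a GELU-activated gate (wi_0) elementwise with a linear projection (wi_1) before projecting back to d_model with wo. A toy sketch, assuming jax.nn.gelu (tanh approximation, like "gelu_new") stands in for ACT2FN["gelu_new"]:

import jax
import jax.numpy as jnp

d_model, d_ff = 4, 8
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k0, (1, d_model))
wi_0 = jax.random.normal(k1, (d_model, d_ff)) * 0.1  # gate projection (toy weights)
wi_1 = jax.random.normal(k2, (d_model, d_ff)) * 0.1  # linear projection
wo = jax.random.normal(k3, (d_ff, d_model)) * 0.1

hidden = jax.nn.gelu(x @ wi_0) * (x @ wi_1)  # GELU-gated activation
out = hidden @ wo                            # project back to d_model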
@@ -154,14 +185,40 @@ class FlaxT5Attention(nn.Module):
self.dropout = self.config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
self.k = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
self.v = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
self.o = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype)
inner_dim_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
d_model_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
self.q = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
dtype=self.dtype,
)
self.k = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
dtype=self.dtype,
)
self.v = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
dtype=self.dtype,
)
self.o = nn.Dense(
self.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(inner_dim_init_std, self.dtype),
dtype=self.dtype,
)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embed(
self.relative_attention_num_buckets, self.n_heads, dtype=self.dtype
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
dtype=self.dtype,
)
@staticmethod
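The relative_attention_bias initialized above is a learned (num_buckets, n_heads) table; at attention time, bucketed relative positions index into it to produce a per-head additive bias on the scores. A shape-only sketch of the lookup, with the bucket computation elided and assuming standard T5 bucketing:

import jax.numpy as jnp

num_buckets, n_heads, q_len, k_len = 32, 8, 4, 4
table = jnp.zeros((num_buckets, n_heads))             # stands in for relative_attention_bias
buckets = jnp.zeros((q_len, k_len), dtype=jnp.int32)  # would come from _relative_position_bucket
bias = table[buckets]                                  # (q_len, k_len, n_heads)
bias = jnp.transpose(bias, (2, 0, 1))[None, ...]       # (1, n_heads, q_len, k_len), added to scores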
@@ -246,7 +303,8 @@ class FlaxT5Attention(nn.Module):
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions
# that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
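The comparison above yields True exactly for the cache slots that already hold generated keys, and False for the still-empty padding slots. A standalone illustration with toy sizes:

import jax.numpy as jnp

max_length, cur_index, num_updated_cache_vectors = 8, 3, 1
pad_mask = jnp.arange(max_length) < cur_index + num_updated_cache_vectors
print(pad_mask)  # [ True  True  True  True False False False False]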
@@ -488,7 +546,6 @@ class FlaxT5Block(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
cross_attn_layer_head_mask=None,
output_attentions=False,
return_dict=True,
deterministic=True,
@@ -527,7 +584,9 @@ class FlaxT5Block(nn.Module):
outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
# returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights)
return outputs
class FlaxT5LayerCollection(nn.Module):
@@ -548,7 +607,6 @@ class FlaxT5LayerCollection(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
cross_attn_layer_head_mask=None,
output_attentions=False,
return_dict=True,
deterministic=True,
@@ -713,7 +771,7 @@ class FlaxT5Stack(nn.Module):
T5_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
@@ -723,23 +781,13 @@ T5_ENCODE_INPUTS_DOCSTRING = r"""
To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
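Since T5's pad token id is 0, a mask matching this description can be derived directly from input_ids; the tokenizer normally returns it for you, so the snippet below is only illustrative:

import jax.numpy as jnp

input_ids = jnp.array([[37, 423, 1, 0, 0]])          # 0 is T5's pad token id
attention_mask = (input_ids != 0).astype(jnp.int32)
print(attention_mask)  # [[1 1 1 0 0]]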
@@ -838,7 +886,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_input_ids: jnp.ndarray = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -853,6 +901,11 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if decoder_input_ids is None:
raise ValueError(
"Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
)
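With this check, Flax T5 never derives decoder inputs implicitly; callers shift the targets right themselves. A sketch of the usual construction, assuming T5's convention of reusing the pad token id (0) as decoder_start_token_id:

import jax.numpy as jnp

labels = jnp.array([[37, 423, 1]])   # toy target ids ending in </s>
decoder_start_token_id = 0           # T5 reuses the pad token id here
decoder_input_ids = jnp.concatenate(
    [jnp.full((labels.shape[0], 1), decoder_start_token_id, dtype=labels.dtype), labels[:, :-1]],
    axis=-1,
)
print(decoder_input_ids)  # [[  0  37 423]]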
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
@@ -1078,24 +1131,31 @@ T5_START_DOCSTRING = r"""
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a text-to-text
denoising generative setting.
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its models (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
general usage and behavior.
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
@@ -1107,14 +1167,14 @@ T5_INPUTS_DOCSTRING = r"""
To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
@@ -1129,53 +1189,20 @@ T5_INPUTS_DOCSTRING = r"""
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
the decoder.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
:obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
takes the value of :obj:`inputs_embeds`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
@@ -1242,7 +1269,7 @@ class FlaxT5Module(nn.Module):
Example::
>>> from transformers import T5Tokenizer, T5Model
>>> from transformers import T5Tokenizer, FlaxT5Model
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5Model.from_pretrained('t5-small')
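A plausible continuation of this example (illustrative, not part of the diff; Flax T5 requires explicit decoder inputs, per the check added above):

>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state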
@@ -1310,7 +1337,11 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
def setup(self):
self.model_dim = self.config.d_model
self.shared = nn.Embed(self.config.vocab_size, self.config.d_model)
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
@@ -1324,13 +1355,12 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, self.shared)
self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
dtype=self.dtype,
)
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
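Note that the lm_head kernel above uses std = initializer_factor with no fan-in scaling, which appears to match the PyTorch T5 init for an untied lm_head; when word embeddings are tied, T5 instead rescales hidden states by d_model ** -0.5 before projecting. A toy sketch of that rescale (an assumption based on the PyTorch model, not part of this diff):

import jax.numpy as jnp

d_model, vocab_size = 512, 32128
hidden = jnp.ones((1, 2, d_model))
shared_embedding = jnp.zeros((vocab_size, d_model))  # toy stand-in for `shared`

# with tied word embeddings, T5 rescales hidden states before projecting to logits
hidden = hidden * (d_model ** -0.5)
logits = hidden @ shared_embedding.T                 # (1, 2, vocab_size)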
@@ -1361,12 +1391,12 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> logits = outputs.logits
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids # Batch size 1
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
>>> outputs = model.generate(input_ids)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode if needed (training, first prediction pass)
# Encode
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,