[Flax T5] Fix weight initialization and fix docs (#12327)

* finish t5 flax fixes
* improve naming

This commit is contained in:
parent 12a4457c56
commit 468cda20f2
@@ -84,8 +84,21 @@ class FlaxT5DenseReluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32

     def setup(self):
-        self.wi = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
         self.dropout = nn.Dropout(self.config.dropout_rate)

     def __call__(self, hidden_states, deterministic=True):
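Note: the new kernel_init arguments reproduce T5's fan-in-scaled normal initialization (std = initializer_factor * fan_in ** -0.5) instead of Flax's default. A minimal standalone sketch of the same pattern, with illustrative sizes and names that are not taken from the commit:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Sketch: a feed-forward block whose kernels are drawn from a normal distribution
# with std = factor * fan_in ** -0.5, mirroring the diff above. `d_model`, `d_ff`,
# and `factor` are illustrative attributes, not names from the commit.
class ScaledDense(nn.Module):
    d_model: int = 512
    d_ff: int = 2048
    factor: float = 1.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        wi_std = self.factor * (self.d_model ** -0.5)  # fan-in of wi is d_model
        wo_std = self.factor * (self.d_ff ** -0.5)     # fan-in of wo is d_ff
        self.wi = nn.Dense(self.d_ff, use_bias=False,
                           kernel_init=jax.nn.initializers.normal(wi_std, self.dtype),
                           dtype=self.dtype)
        self.wo = nn.Dense(self.d_model, use_bias=False,
                           kernel_init=jax.nn.initializers.normal(wo_std, self.dtype),
                           dtype=self.dtype)

    def __call__(self, x):
        return self.wo(nn.relu(self.wi(x)))

# Usage sketch: initialize parameters and check the empirical std of wi's kernel.
params = ScaledDense().init(jax.random.PRNGKey(0), jnp.ones((1, 8, 512)))
print(jnp.std(params["params"]["wi"]["kernel"]))  # ~ 512 ** -0.5 ≈ 0.044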
@@ -101,9 +114,27 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.wi_0 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wi_1 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi_0 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wi_1 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
         self.dropout = nn.Dropout(self.config.dropout_rate)
         self.gelu_act = ACT2FN["gelu_new"]

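Note: the hunk only touches setup(); the block's __call__ is not shown. Assuming the standard T5 v1.1 gated-GELU formulation, the three projections combine as in this sketch (sizes illustrative, not the file's exact code):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Sketch of the gated-GELU feed-forward pattern FlaxT5DenseGatedGeluDense implements;
# this __call__ is an assumption based on the standard T5 v1.1 formulation.
class GatedGeluFFN(nn.Module):
    d_model: int = 512
    d_ff: int = 1024

    def setup(self):
        self.wi_0 = nn.Dense(self.d_ff, use_bias=False)
        self.wi_1 = nn.Dense(self.d_ff, use_bias=False)
        self.wo = nn.Dense(self.d_model, use_bias=False)

    def __call__(self, x):
        gelu_branch = nn.gelu(self.wi_0(x), approximate=True)  # "gelu_new" is a tanh-approximated GELU
        linear_branch = self.wi_1(x)                            # acts as a multiplicative gate
        return self.wo(gelu_branch * linear_branch)

x = jnp.ones((1, 4, 512))
params = GatedGeluFFN().init(jax.random.PRNGKey(0), x)
print(GatedGeluFFN().apply(params, x).shape)  # (1, 4, 512)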
@@ -154,14 +185,40 @@ class FlaxT5Attention(nn.Module):
         self.dropout = self.config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim

-        self.q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.k = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.v = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.o = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype)
+        inner_dim_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        d_model_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+
+        self.q = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.k = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.v = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.o = nn.Dense(
+            self.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(inner_dim_init_std, self.dtype),
+            dtype=self.dtype,
+        )

         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embed(
-                self.relative_attention_num_buckets, self.n_heads, dtype=self.dtype
+                self.relative_attention_num_buckets,
+                self.n_heads,
+                embedding_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+                dtype=self.dtype,
             )

     @staticmethod
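Note: the relative-attention bias table is an nn.Embed whose rows now come from the same scaled normal instead of Flax's default embedding initializer. A minimal standalone sketch of that pattern, with illustrative sizes that are not taken from the commit:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative sizes, not values from the commit.
num_buckets, n_heads, init_std = 32, 8, 0.044

bias_table = nn.Embed(
    num_embeddings=num_buckets,  # one row per relative-position bucket
    features=n_heads,            # one bias value per attention head
    embedding_init=jax.nn.initializers.normal(init_std, jnp.float32),
    dtype=jnp.float32,
)
variables = bias_table.init(jax.random.PRNGKey(0), jnp.arange(num_buckets))
print(variables["params"]["embedding"].shape)  # (32, 8)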
@@ -246,7 +303,8 @@ class FlaxT5Attention(nn.Module):
             cached_value.value = value
             num_updated_cache_vectors = query.shape[1]
             cache_index.value = cache_index.value + num_updated_cache_vectors
-            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions
+            # that have already been generated and cached, not the remaining zero elements.
             pad_mask = jnp.broadcast_to(
                 jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                 tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
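Note: the reworded comment describes the same masking logic as before. A small self-contained illustration of how the broadcast mask behaves; the numbers are made up for the example:

import jax.numpy as jnp

# Illustrative values only: a cache of length 8, with 5 positions already filled
# and 1 new query position being written.
max_length, cur_index, num_updated_cache_vectors = 8, 5, 1

# Positions strictly before cur_index + num_updated are valid keys to attend to.
pad_mask = jnp.broadcast_to(
    jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
    (1, num_updated_cache_vectors, max_length),
)
print(pad_mask.astype(jnp.int32))
# [[[1 1 1 1 1 1 0 0]]] -> the single query attends to the 6 cached/current positions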
@@ -488,7 +546,6 @@ class FlaxT5Block(nn.Module):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
         output_attentions=False,
         return_dict=True,
         deterministic=True,
@@ -527,7 +584,9 @@

         outputs = outputs + attention_outputs

-        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+        # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+        # (cross-attention position bias), (cross-attention weights)
+        return outputs


 class FlaxT5LayerCollection(nn.Module):
@@ -548,7 +607,6 @@ class FlaxT5LayerCollection(nn.Module):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
         output_attentions=False,
         return_dict=True,
         deterministic=True,
@@ -713,7 +771,7 @@ class FlaxT5Stack(nn.Module):

 T5_ENCODE_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
             should be able to pad the inputs on both the right and the left.

@@ -723,23 +781,13 @@ T5_ENCODE_INPUTS_DOCSTRING = r"""

             To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
             <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.

             `What are attention masks? <../glossary.html#attention-mask>`__
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
         output_attentions (:obj:`bool`, `optional`):
             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
             tensors for more detail.
@@ -838,7 +886,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
         self,
         input_ids: jnp.ndarray,
         attention_mask: Optional[jnp.ndarray] = None,
-        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_input_ids: jnp.ndarray = None,
         decoder_attention_mask: Optional[jnp.ndarray] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -853,6 +901,11 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict

+        if decoder_input_ids is None:
+            raise ValueError(
+                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+            )
+
         # prepare encoder inputs
         if attention_mask is None:
             attention_mask = jnp.ones_like(input_ids)
@@ -1078,24 +1131,31 @@ T5_START_DOCSTRING = r"""
     Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text
     denoising generative setting.

-    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-    pruning heads etc.)
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its model (such as downloading or saving, resizing the input
+    embeddings, pruning heads etc.)

-    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-    general usage and behavior.
+    This model is also a Flax Linen `flax.nn.Module
+    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

     Parameters:
         config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """

 T5_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
             should be able to pad the inputs on both the right and the left.

@@ -1107,14 +1167,14 @@ T5_INPUTS_DOCSTRING = r"""

             To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
             <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.

             `What are attention masks? <../glossary.html#attention-mask>`__
-        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Indices of decoder input sequence tokens in the vocabulary.

             Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
@@ -1129,53 +1189,20 @@ T5_INPUTS_DOCSTRING = r"""

             To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
             <./t5.html#training>`__.
-        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
-            ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
+        encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
             sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
             the decoder.
-        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
-        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
-            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
-            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
-            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
-
-            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
-            takes the value of :obj:`inputs_embeds`.

         use_cache (:obj:`bool`, `optional`):
             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
@@ -1242,7 +1269,7 @@ class FlaxT5Module(nn.Module):

         Example::

-            >>> from transformers import T5Tokenizer, T5Model
+            >>> from transformers import T5Tokenizer, FlaxT5Model

             >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
             >>> model = FlaxT5Model.from_pretrained('t5-small')
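Note: to round out the corrected docstring example, a short usage sketch of the Flax model's forward pass. Reusing the encoder ids as decoder_input_ids is purely for illustration; in practice decoder inputs come from shifting the target sequence.

from transformers import T5Tokenizer, FlaxT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5Model.from_pretrained("t5-small")

# return_tensors="np" yields numpy arrays, which the Flax model accepts directly.
inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="np")

# Illustration only: feed the same ids to the decoder side.
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=inputs["input_ids"],
)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, d_model)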
@@ -1310,7 +1337,11 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
     def setup(self):
         self.model_dim = self.config.d_model

-        self.shared = nn.Embed(self.config.vocab_size, self.config.d_model)
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+        )

         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
@@ -1324,13 +1355,12 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
         decoder_config.num_layers = self.config.num_decoder_layers
         self.decoder = FlaxT5Stack(decoder_config, self.shared)

-        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
+        self.lm_head = nn.Dense(
+            self.config.vocab_size,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+            dtype=self.dtype,
+        )

     @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1361,12 +1391,12 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            >>> logits = outputs.logits

-            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids  # Batch size 1
+            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
            >>> outputs = model.generate(input_ids)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # Encode if needed (training, first prediction pass)
+        # Encode
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
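Note: following the docstring example above, a hedged end-to-end sketch of summarization with the conditional-generation head. Decoding via a `sequences` field is an assumption about the standard Flax generate output, not something shown in this commit.

from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")

# "summarize:" is the task prefix T5 was trained with for summarization.
input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="np"
).input_ids

outputs = model.generate(input_ids)
# Assumption: Flax generate returns an output object whose generated token ids live in `.sequences`.
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))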