mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add T5 Encoder for Feature Extraction (#8717)
* Add T5 Encoder class for feature extraction * fix T5 encoder add_start_docstrings indent * update init with T5 encoder * update init with TFT5ModelEncoder * remove TFT5ModelEncoder * change T5ModelEncoder order in init * add T5ModelEncoder to transformers init * clean T5ModelEncoder * update init with TFT5ModelEncoder * add TFModelEncoder for Tensorflow * update init with TFT5ModelEncoder * Update src/transformers/models/t5/modeling_t5.py change output from Seq2SeqModelOutput to BaseModelOutput Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * remove encoder_outputs 1. remove encoder_outputs from the function call. 2. remove the encoder_outputs If statement. 3. remove isinstance from return_dict. * Authorize missing decoder keys * remove unnecessary input parameters remove pask_key_values and use_cache * remove use_cache remove use_cache from the forward method * add doctoring for T5 encoder add doctoring for T5 encoder with T5_ENCODER_INPUTS_DOCSTRING * change return_dict to dot access * add T5_ENCODER_INPUTS_DOCSTRING for TF T5 * change TFT5Encoder output type to BaseModelOutput * remove unnecessary parameters for TFT5Encoder * remove unnecessary if statement * add import BaseModelOutput * fix BaseModelOutput typo to TFBaseModelOutput * update T5 doc with T5ModelEncoder * add T5ModelEncoder to tests * finish pytorch * finish docs and mt5 * add mtf to init * fix init * remove n_positions * finish PR * Update src/transformers/models/mt5/modeling_mt5.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/t5/modeling_t5.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/t5/modeling_tf_t5.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/mt5/modeling_tf_mt5.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * make style Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
610cb106a2
commit
40ecaf0c2b
@ -39,6 +39,13 @@ MT5ForConditionalGeneration
|
||||
:members:
|
||||
|
||||
|
||||
MT5EncoderModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MT5EncoderModel
|
||||
:members:
|
||||
|
||||
|
||||
TFMT5Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -51,3 +58,10 @@ TFMT5ForConditionalGeneration
|
||||
|
||||
.. autoclass:: transformers.TFMT5ForConditionalGeneration
|
||||
:members:
|
||||
|
||||
|
||||
TFMT5EncoderModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMT5EncoderModel
|
||||
:members:
|
||||
|
@ -108,6 +108,11 @@ T5ForConditionalGeneration
|
||||
.. autoclass:: transformers.T5ForConditionalGeneration
|
||||
:members: forward, parallelize, deparallelize
|
||||
|
||||
T5EncoderModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.T5EncoderModel
|
||||
:members: forward
|
||||
|
||||
TFT5Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -121,3 +126,9 @@ TFT5ForConditionalGeneration
|
||||
|
||||
.. autoclass:: transformers.TFT5ForConditionalGeneration
|
||||
:members: call
|
||||
|
||||
TFT5EncoderModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFT5EncoderModel
|
||||
:members: call
|
||||
|
@ -506,7 +506,7 @@ if is_torch_available():
|
||||
MobileBertPreTrainedModel,
|
||||
load_tf_weights_in_mobilebert,
|
||||
)
|
||||
from .models.mt5 import MT5ForConditionalGeneration, MT5Model
|
||||
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
||||
from .models.openai import (
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
@ -561,6 +561,7 @@ if is_torch_available():
|
||||
)
|
||||
from .models.t5 import (
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5Model,
|
||||
T5PreTrainedModel,
|
||||
@ -803,7 +804,7 @@ if is_tf_available():
|
||||
TFMobileBertModel,
|
||||
TFMobileBertPreTrainedModel,
|
||||
)
|
||||
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||
from .models.mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
|
||||
from .models.openai import (
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFOpenAIGPTDoubleHeadsModel,
|
||||
@ -826,6 +827,7 @@ if is_tf_available():
|
||||
)
|
||||
from .models.t5 import (
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFT5EncoderModel,
|
||||
TFT5ForConditionalGeneration,
|
||||
TFT5Model,
|
||||
TFT5PreTrainedModel,
|
||||
|
@ -7,7 +7,7 @@ from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
||||
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
|
||||
|
@ -15,7 +15,7 @@
|
||||
""" PyTorch mT5 model. """
|
||||
|
||||
from ...utils import logging
|
||||
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
||||
from ..t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
@ -73,11 +73,33 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
|
||||
config_class = MT5Config
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
r"decoder\.embed_tokens\.weight",
|
||||
r"lm_head\.weight",
|
||||
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
r"decoder\.embed_tokens\.weight",
|
||||
]
|
||||
|
||||
|
||||
class MT5EncoderModel(T5EncoderModel):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.T5EncoderModel`. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MT5EncoderModel, T5Tokenizer
|
||||
>>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> input_ids = tokenizer(article, return_tensors="pt").input_ids
|
||||
>>> outputs = model(input_ids)
|
||||
>>> hidden_state = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
model_type = "mt5"
|
||||
config_class = MT5Config
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
]
|
||||
|
@ -15,7 +15,7 @@
|
||||
""" Tensorflow mT5 model. """
|
||||
|
||||
from ...utils import logging
|
||||
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
|
||||
from ..t5.modeling_tf_t5 import TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
@ -64,3 +64,23 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
|
||||
|
||||
model_type = "mt5"
|
||||
config_class = MT5Config
|
||||
|
||||
|
||||
class TFMT5EncoderModel(TFT5EncoderModel):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.TFT5EncoderModel`. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TFMT5EncoderModel, T5Tokenizer
|
||||
>>> model = TFMT5EncoderModel.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> input_ids = tokenizer(article, return_tensors="tf").input_ids
|
||||
>>> outputs = model(input_ids)
|
||||
>>> hidden_state = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
model_type = "mt5"
|
||||
config_class = MT5Config
|
||||
|
@ -15,6 +15,7 @@ if is_tokenizers_available():
|
||||
if is_torch_available():
|
||||
from .modeling_t5 import (
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5Model,
|
||||
T5PreTrainedModel,
|
||||
@ -24,6 +25,7 @@ if is_torch_available():
|
||||
if is_tf_available():
|
||||
from .modeling_tf_t5 import (
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFT5EncoderModel,
|
||||
TFT5ForConditionalGeneration,
|
||||
TFT5Model,
|
||||
TFT5PreTrainedModel,
|
||||
|
@ -700,7 +700,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
factor = self.config.initializer_factor # Used for testing weights initialization
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight.data.fill_(factor * 1.0)
|
||||
elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
|
||||
elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
|
||||
# Mesh TensorFlow embeddings initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
@ -1082,6 +1082,45 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
T5_ENCODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
||||
should be able to pad the inputs on both the right and the left.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
detail.
|
||||
|
||||
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
||||
@ -1518,3 +1557,80 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
||||
return reordered_decoder_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
|
||||
T5_START_DOCSTRING,
|
||||
)
|
||||
class T5EncoderModel(T5PreTrainedModel):
|
||||
authorized_missing_keys = [
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__(config)
|
||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.use_cache = False
|
||||
encoder_config.is_encoder_decoder = False
|
||||
self.encoder = T5Stack(encoder_config, self.shared)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared = new_embeddings
|
||||
self.encoder.set_input_embeddings(new_embeddings)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, T5EncoderModel
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = T5EncoderModel.from_pretrained('t5-small')
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
@ -32,7 +32,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
|
||||
from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
@ -949,6 +949,48 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
behaviors between training and evaluation).
|
||||
"""
|
||||
|
||||
T5_ENCODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
inputs (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
||||
should be able to pad the inputs on the right or the left.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
|
||||
details.
|
||||
|
||||
To know more on how to prepare :obj:`inputs` for pre-training take a look at `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||
behaviors between training and evaluation).
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
||||
@ -1385,3 +1427,115 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
|
||||
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
||||
return past + (reordered_decoder_past,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
|
||||
T5_START_DOCSTRING,
|
||||
)
|
||||
class TFT5EncoderModel(TFT5PreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
|
||||
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.use_cache = False
|
||||
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import T5Tokenizer, TFT5Model
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = TFT5EncoderModel.from_pretrained('t5-small')
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
|
||||
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
output_attentions = inputs["output_attentions"] if output_attentions else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if output_hidden_states else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if inputs["return_dict"] is not None else self.config.return_dict
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
head_mask=head_mask,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return encoder_outputs
|
||||
|
||||
if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states):
|
||||
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
|
||||
if not cast_bool_to_primitive(output_attentions, self.config.output_attentions):
|
||||
encoder_outputs = encoder_outputs + (None,)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1],
|
||||
attentions=encoder_outputs[2],
|
||||
)
|
||||
|
@ -1361,6 +1361,15 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_mobilebert)
|
||||
|
||||
|
||||
class MT5EncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MT5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
@ -1719,6 +1728,15 @@ class SqueezeBertPreTrainedModel:
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class T5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
@ -997,6 +997,15 @@ class TFMobileBertPreTrainedModel:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMT5EncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMT5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
@ -1142,6 +1151,15 @@ class TFRobertaPreTrainedModel:
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFT5EncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFT5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
@ -554,7 +554,7 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
def config_and_inputs(self):
|
||||
question_encoder_tester = DPRModelTester(self)
|
||||
dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
|
||||
generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
|
||||
generator_tester = T5ModelTester(self, vocab_size=1100)
|
||||
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
||||
|
||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||
|
@ -30,7 +30,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, T5Model, T5Tokenizer
|
||||
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
|
||||
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@ -39,7 +39,6 @@ class T5ModelTester:
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
n_positions=14,
|
||||
batch_size=13,
|
||||
encoder_seq_length=7,
|
||||
decoder_seq_length=9,
|
||||
@ -71,7 +70,6 @@ class T5ModelTester:
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@ -104,7 +102,6 @@ class T5ModelTester:
|
||||
|
||||
config = T5Config(
|
||||
vocab_size=self.vocab_size,
|
||||
n_positions=self.n_positions,
|
||||
d_model=self.hidden_size,
|
||||
d_ff=self.d_ff,
|
||||
d_kv=self.hidden_size // self.num_attention_heads,
|
||||
@ -559,6 +556,144 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
batch_size=13,
|
||||
encoder_seq_length=7,
|
||||
# For common tests
|
||||
use_attention_mask=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
d_ff=37,
|
||||
relative_attention_num_buckets=8,
|
||||
is_training=False,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
is_encoder_decoder=False,
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.encoder_seq_length = encoder_seq_length
|
||||
# For common tests
|
||||
self.seq_length = self.encoder_seq_length
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.d_ff = d_ff
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.is_encoder_decoder = is_encoder_decoder
|
||||
self.scope = None
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
|
||||
|
||||
config = T5Config(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
d_ff=self.d_ff,
|
||||
d_kv=self.hidden_size // self.num_attention_heads,
|
||||
num_layers=self.num_hidden_layers,
|
||||
num_heads=self.num_attention_heads,
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
initializer_factor=self.initializer_factor,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
):
|
||||
model = T5EncoderModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
result = model(input_ids=input_ids)
|
||||
encoder_output = result.last_hidden_state
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_fp16_forward(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
):
|
||||
model = T5EncoderModel(config=config).to(torch_device).half().eval()
|
||||
output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = T5EncoderOnlyModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||
def test_model_fp16_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
model.config.update(model.config.task_specific_params[task])
|
||||
|
||||
|
@ -25,7 +25,7 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import T5Tokenizer, TFT5ForConditionalGeneration, TFT5Model
|
||||
from transformers import T5Tokenizer, TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model
|
||||
|
||||
|
||||
class TFT5ModelTester:
|
||||
@ -295,6 +295,128 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
batch_size=13,
|
||||
encoder_seq_length=7,
|
||||
# For common tests
|
||||
use_attention_mask=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
d_ff=37,
|
||||
relative_attention_num_buckets=8,
|
||||
is_training=False,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
is_encoder_decoder=False,
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.encoder_seq_length = encoder_seq_length
|
||||
# For common tests
|
||||
self.seq_length = self.encoder_seq_length
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.d_ff = d_ff
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.is_encoder_decoder = is_encoder_decoder
|
||||
self.scope = None
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
|
||||
|
||||
config = T5Config(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
d_ff=self.d_ff,
|
||||
d_kv=self.hidden_size // self.num_attention_heads,
|
||||
num_layers=self.num_hidden_layers,
|
||||
num_heads=self.num_attention_heads,
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
initializer_factor=self.initializer_factor,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
):
|
||||
model = TFT5EncoderModel(config=config)
|
||||
result = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
result = model(input_ids=input_ids)
|
||||
encoder_output = result.last_hidden_state
|
||||
|
||||
self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# is not able to be part of a pipeline
|
||||
def test_train_pipeline_custom_model(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
@ -85,6 +85,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"FlaubertForQuestionAnswering",
|
||||
"FunnelBaseModel",
|
||||
"GPT2DoubleHeadsModel",
|
||||
"MT5EncoderModel",
|
||||
"OpenAIGPTDoubleHeadsModel",
|
||||
"ProphetNetDecoder",
|
||||
"ProphetNetEncoder",
|
||||
@ -92,13 +93,16 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"RagSequenceForGeneration",
|
||||
"RagTokenForGeneration",
|
||||
"T5Stack",
|
||||
"T5EncoderModel",
|
||||
"TFDPRContextEncoder",
|
||||
"TFDPREncoder",
|
||||
"TFDPRReader",
|
||||
"TFDPRSpanPredictor",
|
||||
"TFFunnelBaseModel",
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFMT5EncoderModel",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"TFT5EncoderModel",
|
||||
"XLMForQuestionAnswering",
|
||||
"XLMProphetNetDecoder",
|
||||
"XLMProphetNetEncoder",
|
||||
|
Loading…
Reference in New Issue
Block a user