mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Add TFRag (#9002)
* Create modeling_tf_dpr.py * Add TFDPR * Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot last commit accidentally deleted these 4 lines, so I recover them back * Add TFDPR * Add TFDPR * clean up some comments, add TF input-style doc string * Add TFDPR * Make return_dict=False as default * Fix return_dict bug (in .from_pretrained) * Add get_input_embeddings() * Create test_modeling_tf_dpr.py The current version is already passed all 27 tests! Please see the test run at : https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing * fix quality * delete init weights * run fix copies * fix repo consis * del config_class, load_tf_weights They shoud be 'pytorch only' * add config_class back after removing it, test failed ... so totally only removing "use_tf_weights = None" on Lysandre suggestion * newline after .. note:: * import tf, np (Necessary for ModelIntegrationTest) * slow_test from_pretrained with from_pt=True At the moment we don't have TF weights (since we don't have official official TF model) Previously, I did not run slow test, so I missed this bug * Add simple TFDPRModelIntegrationTest Note that this is just a test that TF and Pytorch gives approx. the same output. However, I could not test with the official DPR repo's output yet * upload correct tf model * remove position_ids as missing keys * create modeling_tf_rag * add tests for tf * add tf tests * revert wrong pt commit * further refactor * further refactor * refactor * Update modeling_tf_rag.py - input_processing - fix prepare_input_for_generation (mostly fix generate bug) - bring back from_pretrained hack in order to test generate * delete colab pieces of code * Show case of greedy "generate" Temporarily change from beam_search test to greedy_search test to show case that TF and PT do get equivalent output. * cosmetic update * correct typos * update * push some progress * make easy check * fix rag save from pretrained * Update src/transformers/modeling_tf_utils.py * remove commented out lines * delete unnecessary lines * add simple test case for nq_checkpoint Add nq_checkpoint test to show that current version without hack still fails * temporarily put ugly hack back again * Add TFRagSequenceForGeneration!! * __init__.py , import TFRagSequenceForGeneration * Add TFRagSequence tests! * rag init.py - add TFRagSequenceForGeneration * fix from_pretrained * fix prepare_inputs_for_generation * Beam search for RagToken! * minor clean up * add tf.cast in TFRagModel * More tf.cast * Add all remaining tests (still have issues) * delete all T5 related * make style * fix load weight prefix * fix bart * fix return_dict for tf_rag make all tests pass .. Hooray * fix some tests * fix code quality * fix qualtiy check * finish tests tf rag * add tf rag to docs * remove TFT5 from docstring Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * remove TFT5 from docstring Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Delete outdated comments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * improve doc strings * add generative model classes * fix adjust token logic * refactor generate for TFRag * using shape_list, not _get_shape Co-authored-by: Julien Plu <plu.julien@gmail.com> * axis=[1]->axis=1 * delete NEED_HELP comment * improve readability Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve readability Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve readability Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Indicating model is in a developing state in docstrings As suggested by Julien * small last changes * apply sylvains suggestions * finish tf rag Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: patrickvonplaten <patrick@huggingface.co> Co-authored-by: Julien Plu <plu.julien@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
3ced9b3eb9
commit
696e8a4365
@ -296,7 +296,7 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RAG | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
@ -94,3 +94,24 @@ RagTokenForGeneration
|
||||
|
||||
.. autoclass:: transformers.RagTokenForGeneration
|
||||
:members: forward, generate
|
||||
|
||||
|
||||
TFRagModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRagModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFRagSequenceForGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRagSequenceForGeneration
|
||||
:members: call, generate
|
||||
|
||||
|
||||
TFRagTokenForGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRagTokenForGeneration
|
||||
:members: call, generate
|
||||
|
@ -1130,6 +1130,13 @@ if is_tf_available():
|
||||
]
|
||||
)
|
||||
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
|
||||
_import_structure["models.rag"].extend(
|
||||
[
|
||||
"TFRagModel",
|
||||
"TFRagSequenceForGeneration",
|
||||
"TFRagTokenForGeneration",
|
||||
]
|
||||
)
|
||||
_import_structure["models.roberta"].extend(
|
||||
[
|
||||
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2166,6 +2173,7 @@ if TYPE_CHECKING:
|
||||
TFOpenAIGPTPreTrainedModel,
|
||||
)
|
||||
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
||||
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
|
||||
from .models.roberta import (
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFRobertaForMaskedLM,
|
||||
|
@ -441,6 +441,7 @@ class TFGenerationMixin:
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated
|
||||
@ -455,7 +456,7 @@ class TFGenerationMixin:
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
@ -609,6 +610,7 @@ class TFGenerationMixin:
|
||||
use_cache,
|
||||
forced_bos_token_id,
|
||||
forced_eos_token_id,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
@ -637,7 +639,7 @@ class TFGenerationMixin:
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
@ -447,7 +447,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
return output
|
||||
|
||||
|
||||
def load_tf_weights(model, resolved_archive_file):
|
||||
def load_tf_weights(model, resolved_archive_file, _prefix=None):
|
||||
"""
|
||||
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
||||
|
||||
@ -493,6 +493,10 @@ def load_tf_weights(model, resolved_archive_file):
|
||||
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
|
||||
# TF names always start with the model name so we ignore it
|
||||
name = "/".join(weight_name.split("/")[1:])
|
||||
|
||||
if _prefix is not None:
|
||||
name = _prefix + "/" + name
|
||||
|
||||
saved_weights[name] = np.asarray(h5_layer_object[weight_name])
|
||||
|
||||
# Add the updated name to the final list for computing missing/unexpected values
|
||||
@ -501,7 +505,14 @@ def load_tf_weights(model, resolved_archive_file):
|
||||
# Loop over each weights from the instantiated model and compare with the weights from the H5 file
|
||||
for symbolic_weight in symbolic_weights:
|
||||
# TF names always start with the model name so we ignore it
|
||||
symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
|
||||
if _prefix is not None:
|
||||
delimeter = len(_prefix.split("/"))
|
||||
symbolic_weight_name = "/".join(
|
||||
symbolic_weight.name.split("/")[:delimeter]
|
||||
+ symbolic_weight.name.split("/")[delimeter + 1 :]
|
||||
)
|
||||
else:
|
||||
symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
|
||||
|
||||
# here we check if the current weight is among the weights from the H5 file
|
||||
# If yes, get the weight_value of the corresponding weight from the H5 file
|
||||
@ -603,6 +614,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
# a list of re pattern of tensor names to ignore from the weights when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
_keys_to_ignore_on_load_unexpected = None
|
||||
_requires_load_weight_prefix = False
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
@ -741,10 +753,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
|
||||
def get_prefix_bias_name(self) -> Union[None, str]:
|
||||
"""
|
||||
Get the concatenated prefix name of the bias from the model name to the parent layer
|
||||
Get the concatenated _prefix name of the bias from the model name to the parent layer
|
||||
|
||||
Return:
|
||||
:obj:`str`: The prefix name of the bias.
|
||||
:obj:`str`: The _prefix name of the bias.
|
||||
"""
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
return None
|
||||
@ -1052,7 +1064,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
:func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
|
||||
@ -1151,6 +1163,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
@ -1230,6 +1243,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
|
||||
# pre-trained weights.
|
||||
if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
|
||||
model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
|
||||
|
||||
# Instantiate model.
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
@ -1239,13 +1257,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
|
||||
|
||||
model(model.dummy_inputs) # build the network with dummy inputs
|
||||
# we might need to extend the variable scope for composite models
|
||||
if load_weight_prefix is not None:
|
||||
with tf.compat.v1.variable_scope(load_weight_prefix):
|
||||
model(model.dummy_inputs) # build the network with dummy inputs
|
||||
else:
|
||||
model(model.dummy_inputs) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
try:
|
||||
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
|
||||
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
|
||||
except OSError:
|
||||
raise OSError(
|
||||
"Unable to load weights from h5 file. "
|
||||
|
@ -553,7 +553,7 @@ class TFAutoModel(object):
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(TF_MODEL_MAPPING, use_model_types=False)
|
||||
def from_config(cls, config):
|
||||
def from_config(cls, config, **kwargs):
|
||||
r"""
|
||||
Instantiates one of the base model classes of the library from a configuration.
|
||||
|
||||
@ -575,7 +575,7 @@ class TFAutoModel(object):
|
||||
>>> model = TFAutoModel.from_config(config)
|
||||
"""
|
||||
if type(config) in TF_MODEL_MAPPING.keys():
|
||||
return TF_MODEL_MAPPING[type(config)](config)
|
||||
return TF_MODEL_MAPPING[type(config)](config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
@ -1037,7 +1037,7 @@ class TFAutoModelForSeq2SeqLM:
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False)
|
||||
def from_config(cls, config):
|
||||
def from_config(cls, config, **kwargs):
|
||||
r"""
|
||||
Instantiates one of the model classes of the library---with a sequence-to-sequence language modeling
|
||||
head---from a configuration.
|
||||
@ -1061,7 +1061,7 @@ class TFAutoModelForSeq2SeqLM:
|
||||
>>> model = TFAutoModelForSeq2SeqLM.from_config(config)
|
||||
"""
|
||||
if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
|
||||
return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config)
|
||||
return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
|
@ -1015,13 +1015,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BartConfig
|
||||
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
# set tf scope correctly
|
||||
if load_weight_prefix is None:
|
||||
load_weight_prefix = "model.shared"
|
||||
|
||||
with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
@ -1157,10 +1160,13 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class TFBartModel(TFBartPretrainedModel):
|
||||
def __init__(self, config: BartConfig, *inputs, **kwargs):
|
||||
|
||||
_requires_load_weight_prefix = True
|
||||
|
||||
def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TFBartMainLayer(config, name="model")
|
||||
self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
@ -1263,9 +1269,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
||||
r"model.decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
_requires_load_weight_prefix = True
|
||||
|
||||
def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFBartMainLayer(config, name="model")
|
||||
self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
|
@ -18,7 +18,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_torch_available
|
||||
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -30,6 +30,9 @@ _import_structure = {
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_rag import RagConfig
|
||||
@ -39,6 +42,9 @@ if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
|
1832
src/transformers/models/rag/modeling_tf_rag.py
Normal file
1832
src/transformers/models/rag/modeling_tf_rag.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1332,6 +1332,25 @@ class TFPegasusModel:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFRagModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFRagSequenceForGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFRagTokenForGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
1102
tests/test_modeling_tf_rag.py
Normal file
1102
tests/test_modeling_tf_rag.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -121,6 +121,9 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFMT5EncoderModel",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"TFRagModel",
|
||||
"TFRagSequenceForGeneration",
|
||||
"TFRagTokenForGeneration",
|
||||
"TFT5EncoderModel",
|
||||
"Wav2Vec2ForCTC",
|
||||
"XLMForQuestionAnswering",
|
||||
|
Loading…
Reference in New Issue
Block a user