Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-25 23:38:59 +06:00)
TF port of ESM (#19587)
* Partial TF port for ESM model
* Add ESM-TF tests
* Add the various imports for TF-ESM
* TF weight conversion almost ready
* Stop ignoring the decoder weights in PT
* Add tests and lots of fixes
* fix-copies
* Fix imports, add model docs
* Add get_vocab() to tokenizer
* Fix vocab links for pretrained files
* Allow multiple inputs with a sep
* Use EOS as SEP token because ESM vocab lacks SEP
* Correctly return special tokens mask from ESM tokenizer
* make fixup
* Stop testing unsupported embedding resizing
* Handle TF bias correctly
* Skip all models with slow tokenizers in the token classification test
* Fix the batcher/unbatcher of pipelines to accommodate the `None` being passed around
* Fix pipeline bug caused by the slow tokenizer being different
* Update src/transformers/models/esm/modeling_tf_esm.py
* Update set_input_embeddings and the copyright notices

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent d7754c43d0
commit 3b3024da70
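For context, once this lands the new TF classes can be used like any other TensorFlow model in the library. A minimal sketch, not part of the diff; the checkpoint name and token ids are taken from the integration tests added below:

```python
# Minimal usage sketch for the new TF-ESM classes (not part of this diff).
# Assumes TensorFlow is installed and the test checkpoint is reachable on the Hub.
import tensorflow as tf
from transformers import TFEsmForMaskedLM, TFEsmModel

checkpoint = "Rocketknight1/esm2_t6_8M_UR50D"  # checkpoint used by the new tests

# Base model: returns hidden states for a batch of token ids.
model = TFEsmModel.from_pretrained(checkpoint)
input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])  # ids from the new tests
hidden_states = model(input_ids).last_hidden_state
print(hidden_states.shape)  # (1, 11, hidden_size)

# Masked-LM head: returns per-token vocabulary logits.
mlm = TFEsmForMaskedLM.from_pretrained(checkpoint)
logits = mlm(tf.constant([[0, 1, 2, 3, 4, 5]])).logits
print(logits.shape)  # (1, 6, 33) for this checkpoint, per the new integration test
```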
@@ -243,7 +243,7 @@ Flax), PyTorch, and/or TensorFlow.
 | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
 | ERNIE | ❌ | ❌ | ✅ | ❌ | ❌ |
-| ESM | ✅ | ❌ | ✅ | ❌ | ❌ |
+| ESM | ✅ | ❌ | ✅ | ✅ | ❌ |
 | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
 | FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ |
 | FLAVA | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -107,3 +107,23 @@ and [Matt](https://huggingface.co/Rocketknight1).
 
 [[autodoc]] EsmForTokenClassification
     - forward
+
+## TFEsmModel
+
+[[autodoc]] TFEsmModel
+    - call
+
+## TFEsmForMaskedLM
+
+[[autodoc]] TFEsmForMaskedLM
+    - call
+
+## TFEsmForSequenceClassification
+
+[[autodoc]] TFEsmForSequenceClassification
+    - call
+
+## TFEsmForTokenClassification
+
+[[autodoc]] TFEsmForTokenClassification
+    - call
@@ -2462,6 +2462,16 @@ else:
         ]
     )
     _import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
+    _import_structure["models.esm"].extend(
+        [
+            "ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFEsmForMaskedLM",
+            "TFEsmForSequenceClassification",
+            "TFEsmForTokenClassification",
+            "TFEsmModel",
+            "TFEsmPreTrainedModel",
+        ]
+    )
     _import_structure["models.flaubert"].extend(
         [
             "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5134,6 +5144,14 @@ if TYPE_CHECKING:
             TFElectraPreTrainedModel,
         )
         from .models.encoder_decoder import TFEncoderDecoderModel
+        from .models.esm import (
+            ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFEsmForMaskedLM,
+            TFEsmForSequenceClassification,
+            TFEsmForTokenClassification,
+            TFEsmModel,
+            TFEsmPreTrainedModel,
+        )
         from .models.flaubert import (
             TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
             TFFlaubertForMultipleChoice,
@@ -47,6 +47,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("distilbert", "TFDistilBertModel"),
         ("dpr", "TFDPRQuestionEncoder"),
         ("electra", "TFElectraModel"),
+        ("esm", "TFEsmModel"),
         ("flaubert", "TFFlaubertModel"),
         ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
         ("gpt2", "TFGPT2Model"),
@@ -129,6 +130,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
         ("ctrl", "TFCTRLLMHeadModel"),
         ("distilbert", "TFDistilBertForMaskedLM"),
         ("electra", "TFElectraForMaskedLM"),
+        ("esm", "TFEsmForMaskedLM"),
         ("flaubert", "TFFlaubertWithLMHeadModel"),
         ("funnel", "TFFunnelForMaskedLM"),
         ("gpt2", "TFGPT2LMHeadModel"),
@@ -223,6 +225,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
         ("deberta-v2", "TFDebertaV2ForMaskedLM"),
         ("distilbert", "TFDistilBertForMaskedLM"),
         ("electra", "TFElectraForMaskedLM"),
+        ("esm", "TFEsmForMaskedLM"),
         ("flaubert", "TFFlaubertWithLMHeadModel"),
         ("funnel", "TFFunnelForMaskedLM"),
         ("layoutlm", "TFLayoutLMForMaskedLM"),
@@ -273,6 +276,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
         ("distilbert", "TFDistilBertForSequenceClassification"),
         ("electra", "TFElectraForSequenceClassification"),
+        ("esm", "TFEsmForSequenceClassification"),
         ("flaubert", "TFFlaubertForSequenceClassification"),
         ("funnel", "TFFunnelForSequenceClassification"),
         ("gpt2", "TFGPT2ForSequenceClassification"),
@@ -346,6 +350,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("deberta-v2", "TFDebertaV2ForTokenClassification"),
         ("distilbert", "TFDistilBertForTokenClassification"),
         ("electra", "TFElectraForTokenClassification"),
+        ("esm", "TFEsmForTokenClassification"),
         ("flaubert", "TFFlaubertForTokenClassification"),
         ("funnel", "TFFunnelForTokenClassification"),
         ("layoutlm", "TFLayoutLMForTokenClassification"),
@@ -122,6 +122,7 @@ else:
         ),
         ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
         ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+        ("esm", ("EsmTokenizer", None)),
         ("flaubert", ("FlaubertTokenizer", None)),
         ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
         ("fsmt", ("FSMTTokenizer", None)),
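The mapping entries above mean the generic auto classes can now resolve `model_type == "esm"` to the TF implementations and to the slow `EsmTokenizer`. A hedged sketch of what that enables, not part of the diff; class names are taken from the mappings added in these hunks and the checkpoint is the one used elsewhere in this PR:

```python
# Sketch only: relies on the ("esm", "TFEsmModel"), ("esm", "TFEsmForMaskedLM"),
# and ("esm", ("EsmTokenizer", None)) entries registered above.
from transformers import AutoTokenizer, TFAutoModel, TFAutoModelForMaskedLM

checkpoint = "Rocketknight1/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)      # resolves to EsmTokenizer (no fast tokenizer)
model = TFAutoModel.from_pretrained(checkpoint)            # resolves to TFEsmModel
mlm = TFAutoModelForMaskedLM.from_pretrained(checkpoint)   # resolves to TFEsmForMaskedLM
```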
@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
 
 
 _import_structure = {
@@ -40,6 +40,21 @@ else:
         "EsmPreTrainedModel",
     ]
 
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_esm"] = [
+        "TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFEsmForMaskedLM",
+        "TFEsmForSequenceClassification",
+        "TFEsmForTokenClassification",
+        "TFEsmModel",
+        "TFEsmPreTrainedModel",
+    ]
+
 
 if TYPE_CHECKING:
     from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
@@ -60,6 +75,21 @@ if TYPE_CHECKING:
             EsmPreTrainedModel,
         )
 
+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_esm import (
+            TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFEsmForMaskedLM,
+            TFEsmForSequenceClassification,
+            TFEsmForTokenClassification,
+            TFEsmModel,
+            TFEsmPreTrainedModel,
+        )
+
 
 else:
     import sys
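The two hunks above follow the library's standard optional-dependency pattern: the TF symbols are only registered in the lazy `_import_structure` (and only imported under `TYPE_CHECKING`) when TensorFlow is installed; otherwise the dummy objects added further down serve a helpful error. A condensed, simplified illustration of the pattern, not the actual module:

```python
# Simplified illustration of the optional-dependency pattern used above.
# In the real module, _LazyModule resolves these names on first access.
from transformers.utils import OptionalDependencyNotAvailable, is_tf_available

_import_structure = {"configuration_esm": ["EsmConfig"]}

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # TF objects stay out of the import structure; dummy objects are served instead
else:
    _import_structure["modeling_tf_esm"] = ["TFEsmModel", "TFEsmForMaskedLM"]
```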
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 Facebook and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -42,12 +42,14 @@ from .configuration_esm import EsmConfig
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "facebook/esm-1b"
+_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D"
 _CONFIG_FOR_DOC = "EsmConfig"
 _TOKENIZER_FOR_DOC = "EsmTokenizer"
 
 ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/esm-1b",
+    "Rocketknight1/esm2_t6_8M_UR50D",
+    "Rocketknight1/esm2_t12_35M_UR50D",
+    # This is not a complete list of all ESM models!
     # See all ESM models at https://huggingface.co/models?filter=esm
 ]
 
@@ -115,7 +117,6 @@ class EsmEmbeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
 
         if config.emb_layer_norm_before:
             self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -658,15 +659,6 @@ class EsmPreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def update_keys_to_ignore(self, config, del_keys_to_ignore):
-        """Remove some keys from ignore list"""
-        if not config.tie_word_embeddings:
-            # must make a new list, or the class variable gets modified!
-            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-            self._keys_to_ignore_on_load_missing = [
-                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-            ]
-
 
 ESM_START_DOCSTRING = r"""
 
@@ -907,8 +899,7 @@ class EsmModel(EsmPreTrainedModel):
 
 @add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
 class EsmForMaskedLM(EsmPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
@@ -923,9 +914,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
         self.esm = EsmModel(config, add_pooling_layer=False)
         self.lm_head = EsmLMHead(config)
 
-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-
         self.init_weights()
 
     def get_output_embeddings(self):
@@ -944,17 +932,17 @@ class EsmForMaskedLM(EsmPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1009,17 +997,13 @@ class EsmLMHead(nn.Module):
         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
 
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
     def forward(self, features, **kwargs):
         x = self.dense(features)
         x = gelu(x)
         x = self.layer_norm(x)
 
         # project back to size of vocabulary with bias
-        x = self.decoder(x)
+        x = self.decoder(x) + self.bias
 
         return x
 
 
@@ -1052,15 +1036,15 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1148,15 +1132,15 @@ class EsmForTokenClassification(EsmPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
src/transformers/models/esm/modeling_tf_esm.py (new file, 1399 lines; diff suppressed because it is too large)
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright Facebook and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,12 +27,18 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "facebook/esm1b": "https://huggingface.co/facebook/esm1b/resolve/main/vocab.txt",
+        "Rocketknight1/esm2_t6_8M_UR50D": (
+            "https://huggingface.co/Rocketknight1/esm2_t6_8M_UR50D/resolve/main/vocab.txt"
+        ),
+        "Rocketknight1/esm2_t12_35M_UR50D": (
+            "https://huggingface.co/Rocketknight1/esm2_t12_35M_UR50D/resolve/main/vocab.txt"
+        ),
     },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    "facebook/esm1b": 1024,
+    "Rocketknight1/esm2_t6_8M_UR50D": 1024,
+    "Rocketknight1/esm2_t12_35M_UR50D": 1024,
 }
 
 
@@ -77,6 +83,9 @@ class EsmTokenizer(PreTrainedTokenizer):
     def get_vocab_size(self, with_added_tokens=False):
         return len(self._id_to_token)
 
+    def get_vocab(self):
+        return {token: i for i, token in enumerate(self.all_tokens)}
+
     def token_to_id(self, token: str) -> int:
         return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
 
@@ -86,11 +95,42 @@ class EsmTokenizer(PreTrainedTokenizer):
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
-        if token_ids_1 is not None:
-            raise ValueError("Multiple input sentences are not supported!")
-        cls_: List[int] = [self.cls_token_id]
-        eos_: List[int] = [self.eos_token_id]
-        return cls_ + token_ids_0 + eos_
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.eos_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.eos_token_id]  # No sep token in ESM vocabulary
+        return cls + token_ids_0 + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids of the first sequence.
+            token_ids_1 (`List[int]`, *optional*):
+                List of ids of the second sequence.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+
+            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
+        mask = [1] + ([0] * len(token_ids_0)) + [1]
+        if token_ids_1 is not None:
+            mask += [0] * len(token_ids_1) + [1]
+        return mask
 
     def save_vocabulary(self, save_directory, filename_prefix):
         vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
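Because the ESM vocabulary has no dedicated SEP token, the tokenizer changes above reuse EOS as the separator and derive the special-tokens mask from that layout. A hedged sketch of the resulting layouts, not part of the diff; the checkpoint is the one used in this PR's tests and the ids are purely illustrative:

```python
# Sketch of the layouts produced by the updated methods above.
# Single sequence:  <cls> t0 ... tn <eos>
# Pair:             <cls> seq0 <eos> seq1 <eos>   (EOS doubles as SEP)
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
ids_a, ids_b = [5, 6, 7], [8, 9]  # already-tokenized ids, purely illustrative

single = tokenizer.build_inputs_with_special_tokens(ids_a)
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
mask = tokenizer.get_special_tokens_mask(ids_a, ids_b)

assert mask == [1, 0, 0, 0, 1, 0, 0, 1]  # 1 marks <cls>/<eos>, 0 marks sequence tokens
assert len(mask) == len(pair)
```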
|
@ -138,7 +138,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
# For multi masks though, the other [MASK] would be removed otherwise
|
# For multi masks though, the other [MASK] would be removed otherwise
|
||||||
# making the output look odd, so we add them back
|
# making the output look odd, so we add them back
|
||||||
sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
|
sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
|
||||||
proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
|
proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
|
||||||
row.append(proposition)
|
row.append(proposition)
|
||||||
result.append(row)
|
result.append(row)
|
||||||
if single_mask:
|
if single_mask:
|
||||||
|
@@ -83,7 +83,10 @@ class PipelineIterator(IterableDataset):
             elif isinstance(element[0], np.ndarray):
                 loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                 continue
-            if isinstance(element[self._loader_batch_index], torch.Tensor):
+            if element is None:
+                # This can happen for optional data that get passed around
+                loader_batched[k] = None
+            elif isinstance(element[self._loader_batch_index], torch.Tensor):
                 # Take correct batch data, but make it looked like batch_size=1
                 # For compatibility with other methods within transformers
 
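The branch added above lets the pipeline unbatcher forward optional entries (values that are `None`) instead of trying to index into them. A simplified, standalone sketch of the idea, not the library's actual implementation; the `offset_mapping` key is only an illustrative example of optional data:

```python
# Simplified sketch of the un-batching rule added above: None values (optional
# pipeline data) are forwarded as-is, tensors are sliced down to a batch of size 1.
import torch

def unbatch_item(batched: dict, index: int) -> dict:
    item = {}
    for key, element in batched.items():
        if element is None:
            item[key] = None  # optional data that just gets passed around
        elif isinstance(element, torch.Tensor):
            item[key] = element[index].unsqueeze(0)  # keep a batch dimension of 1
        else:
            item[key] = element[index]
    return item

batch = {"input_ids": torch.ones(4, 7, dtype=torch.long), "offset_mapping": None}
print(unbatch_item(batch, 2)["input_ids"].shape)  # torch.Size([1, 7])
```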
@@ -1142,6 +1142,44 @@ class TFEncoderDecoderModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
+ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFEsmForMaskedLM(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFEsmForSequenceClassification(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFEsmForTokenClassification(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFEsmModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFEsmPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
@@ -240,6 +240,14 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.assertEqual(position_ids.shape, expected_positions.shape)
         self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
 
+    @unittest.skip("Esm does not support embedding resizing")
+    def test_resize_embeddings_untied(self):
+        pass
+
+    @unittest.skip("Esm does not support embedding resizing")
+    def test_resize_tokens_embeddings(self):
+        pass
+
 
 @require_torch
 class EsmModelIntegrationTest(TestCasePlus):
@@ -270,24 +278,3 @@ class EsmModelIntegrationTest(TestCasePlus):
             [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
-
-    def test_lm_head_ignore_keys(self):
-        from copy import deepcopy
-
-        keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-        keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
-        config = EsmConfig.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
-        config_tied = deepcopy(config)
-        config_tied.tie_word_embeddings = True
-        config_untied = deepcopy(config)
-        config_untied.tie_word_embeddings = False
-        for cls in [EsmForMaskedLM]:
-            model = cls(config_tied)
-            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
-
-            # the keys should be different when embeddings aren't tied
-            model = cls(config_untied)
-            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
-
-            # test that saving works with updated ignore keys - just testing that it doesn't fail
-            model.save_pretrained(self.get_auto_remove_tmp_dir())
tests/models/esm/test_modeling_tf_esm.py (new file, 287 lines)
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers import EsmConfig, is_tf_available
+from transformers.testing_utils import require_tf, slow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_tf_available():
+    import numpy
+    import tensorflow as tf
+
+    from transformers.models.esm.modeling_tf_esm import (
+        TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFEsmForMaskedLM,
+        TFEsmForSequenceClassification,
+        TFEsmForTokenClassification,
+        TFEsmModel,
+    )
+
+
+# copied from tests.test_modeling_tf_roberta
+class TFEsmModelTester:
+    def __init__(
+        self,
+        parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_mask = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.intermediate_size = 37
+        self.hidden_act = "gelu"
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_vocab_size = 16
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.scope = None
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        input_mask = None
+        if self.use_input_mask:
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        sequence_labels = None
+        token_labels = None
+        choice_labels = None
+        if self.use_labels:
+            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+        config = EsmConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            pad_token_id=1,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        config.is_decoder = True
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
+    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        model = TFEsmModel(config=config)
+        inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+        result = model(inputs)
+
+        inputs = [input_ids, input_mask]
+        result = model(inputs)
+
+        result = model(input_ids)
+
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_model_as_decoder(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.add_cross_attention = True
+
+        model = TFEsmModel(config=config)
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_attention_mask": encoder_attention_mask,
+        }
+        result = model(inputs)
+
+        inputs = [input_ids, input_mask]
+        result = model(inputs, encoder_hidden_states=encoder_hidden_states)
+
+        # Also check the case where encoder outputs are not passed
+        result = model(input_ids, attention_mask=input_mask)
+
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_for_masked_lm(
+        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = TFEsmForMaskedLM(config=config)
+        result = model([input_ids, input_mask])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+    def create_and_check_for_token_classification(
+        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_labels = self.num_labels
+        model = TFEsmForTokenClassification(config=config)
+        inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+        result = model(inputs)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+        return config, inputs_dict
+
+
+@require_tf
+class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            TFEsmModel,
+            TFEsmForMaskedLM,
+            TFEsmForSequenceClassification,
+            TFEsmForTokenClassification,
+        )
+        if is_tf_available()
+        else ()
+    )
+    test_head_masking = False
+    test_onnx = False
+
+    def setUp(self):
+        self.model_tester = TFEsmModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        """Test the base model"""
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_as_decoder(self):
+        """Test the base model as a decoder (of an encoder-decoder architecture)
+
+        is_deocder=True + cross_attention + pass encoder outputs
+        """
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
+
+    def test_for_masked_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_name in TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            model = TFEsmModel.from_pretrained(model_name)
+            self.assertIsNotNone(model)
+
+    @unittest.skip("Protein models do not support embedding resizing.")
+    def test_resize_token_embeddings(self):
+        pass
+
+    @unittest.skip("Protein models do not support embedding resizing.")
+    def test_save_load_after_resize_token_embeddings(self):
+        pass
+
+
+@require_tf
+class TFEsmModelIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_masked_lm(self):
+        model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+
+        input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
+        output = model(input_ids)[0]
+        expected_shape = [1, 6, 33]
+        self.assertEqual(list(output.numpy().shape), expected_shape)
+        # compare the actual values for a slice.
+        expected_slice = tf.constant(
+            [[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]]
+        )
+        self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
+
+    @slow
+    def test_inference_no_head(self):
+        model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+
+        input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        output = model(input_ids)[0]
+        # compare the actual values for a slice.
+        expected_slice = tf.constant(
+            [
+                [
+                    [0.144337, 0.541198, 0.32479298],
+                    [0.30328932, 0.00519154, 0.31089523],
+                    [0.32273883, -0.24992886, 0.34143737],
+                ]
+            ]
+        )
+        self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@@ -44,6 +44,8 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
     def run_pipeline_test(self, token_classifier, _):
         model = token_classifier.model
         tokenizer = token_classifier.tokenizer
+        if not tokenizer.is_fast:
+            return  # Slow tokenizers do not return offsets mappings, so this test will fail
 
         outputs = token_classifier("A simple string")
         self.assertIsInstance(outputs, list)