# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ESM model. """


import unittest

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


if is_torch_available():
    import torch

    from transformers import EsmForMaskedLM, EsmForSequenceClassification, EsmForTokenClassification, EsmModel
    from transformers.models.esm.modeling_esm import (
        ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
        EsmEmbeddings,
        create_position_ids_from_input_ids,
    )


# copied from tests.test_modeling_roberta
class EsmModelTester:
    def __init__(
        self,
        parent,
    ):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = False
        self.use_input_mask = True
        self.use_token_type_ids = False
        self.use_labels = True
        self.vocab_size = 99
        self.hidden_size = 32
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return EsmConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            pad_token_id=1,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
        )
    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        model = EsmModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
    def create_and_check_for_masked_lm(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = EsmForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_token_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = EsmForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    test_mismatched_shapes = False

    all_model_classes = (
        (
            EsmForMaskedLM,
            EsmModel,
            EsmForSequenceClassification,
            EsmForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = ()
    test_sequence_classification_problem_types = True

    def setUp(self):
        self.model_tester = EsmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = EsmModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
    def test_create_position_ids_respects_padding_index(self):
        """Ensure that the default position ids only assign a sequential order to non-padding tokens.
        This is a regression test for https://github.com/huggingface/transformers/issues/1761

        The position ids should be masked with the embedding object's padding index. Therefore, the
        first available non-padding position index is EsmEmbeddings.padding_idx + 1
        """
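        # Worked example: with this tester config (pad_token_id=1, hence padding_idx=1),
        # input_ids [[12, 31, 13, 1]] should map to position_ids [[2, 3, 4, 1]]:
        # sequential ids starting at padding_idx + 1, with the padding position left
        # at padding_idx itself.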
        config = self.model_tester.prepare_config_and_inputs()[0]
        model = EsmEmbeddings(config=config)

        input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
        expected_positions = torch.as_tensor(
            [
                [
                    0 + model.padding_idx + 1,
                    1 + model.padding_idx + 1,
                    2 + model.padding_idx + 1,
                    model.padding_idx,
                ]
            ]
        )
        position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
    def test_create_position_ids_from_inputs_embeds(self):
        """Ensure that the default position ids only assign a sequential order to non-padding tokens.
        This is a regression test for https://github.com/huggingface/transformers/issues/1761

        The position ids should be masked with the embedding object's padding index. Therefore, the
        first available non-padding position index is EsmEmbeddings.padding_idx + 1
        """
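        # Worked example: when only inputs_embeds are given there are no token ids to
        # inspect, so no position can be treated as padding; every position in a
        # (2, 4, 30) embedding tensor gets a sequential id starting at padding_idx + 1,
        # i.e. [2, 3, 4, 5] per row when padding_idx is 1.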
        config = self.model_tester.prepare_config_and_inputs()[0]
        embeddings = EsmEmbeddings(config=config)

        inputs_embeds = torch.empty(2, 4, 30)
        expected_single_positions = [
            0 + embeddings.padding_idx + 1,
            1 + embeddings.padding_idx + 1,
            2 + embeddings.padding_idx + 1,
            3 + embeddings.padding_idx + 1,
        ]
        expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
        position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))


@require_torch
class EsmModelIntegrationTest(TestCasePlus):
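    # NOTE: these integration tests currently load checkpoints from a temporary staging
    # repo ("Rocketknight1/esm-2-8m") used while the ESM-2 port was in progress; the
    # expected values below were computed against that checkpoint.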
    @slow
    def test_inference_masked_lm(self):
        model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm-2-8m")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        vocab_size = 33

        expected_shape = torch.Size((1, 6, vocab_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head(self):
        model = EsmModel.from_pretrained("Rocketknight1/esm-2-8m")

        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
        output = model(input_ids)[0]
        # compare the actual values for a slice.
        expected_slice = torch.tensor(
            [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
    def test_lm_head_ignore_keys(self):
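        # When tie_word_embeddings is True, lm_head.decoder.weight is shared with the
        # input embeddings, so it does not need to be serialized separately and should
        # appear in _keys_to_ignore_on_save; when untied, only the decoder bias is ignored.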
        from copy import deepcopy

        keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
        keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
        config = EsmConfig.from_pretrained("Rocketknight1/esm-2-8m")
        config_tied = deepcopy(config)
        config_tied.tie_word_embeddings = True
        config_untied = deepcopy(config)
        config_untied.tie_word_embeddings = False
        for cls in [EsmForMaskedLM]:
            model = cls(config_tied)
            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)

            # the keys should be different when embeddings aren't tied
            model = cls(config_untied)
            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)

            # test that saving works with updated ignore keys - just testing that it doesn't fail
            model.save_pretrained(self.get_auto_remove_tmp_dir())