[Pix2Struct] Add support to resize embeddings (#22394)

* First draft

* Fix integration test

* Remove script

* Fix test and typos

* Fix one more test

* Skip tied embeddings test

* Remove line

* Address comments
NielsRogge 2023-03-27 17:38:07 +02:00 committed by GitHub
parent f6b80a0139
commit 0e708178ed
3 changed files with 143 additions and 25 deletions
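
In practical terms, the new hooks let a user grow the Pix2Struct vocabulary after adding tokens to the tokenizer. A minimal usage sketch (the checkpoint is the one referenced in the docstrings below; the added token and the use of `Pix2StructProcessor` are illustrative, not part of this diff):

```python
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Illustrative checkpoint; any Pix2Struct checkpoint should behave the same way.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Add a (hypothetical) new token, then resize the decoder embeddings to match.
processor.tokenizer.add_tokens(["<my_new_token>"])
model.resize_token_embeddings(len(processor.tokenizer))

# The override added in this commit keeps the nested text config in sync.
assert model.config.text_config.vocab_size == len(processor.tokenizer)
```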

src/transformers/models/pix2struct/configuration_pix2struct.py

@@ -35,17 +35,16 @@ class Pix2StructTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate
a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a
-configuration with the defaults will yield a similar configuration to that of the `Pix2StructText` used by the
-[base architectures](https://huggingface.co/google/pix2struct-textcaps-base).
+configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by
+the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50244):
Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be
-represented by the `inputs_ids` passed when calling [`Pix2StructModel`].
+represented by the `inputs_ids` passed when calling [`Pix2StructTextModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
d_kv (`int`, *optional*, defaults to 64):
@@ -83,10 +82,10 @@ class Pix2StructTextConfig(PretrainedConfig):
```python
>>> from transformers import Pix2StructTextConfig, Pix2StructTextModel
->>> # Initializing a Pix2StructTextConfig with Salesforce/pix2struct-vqa-base style configuration
+>>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration
>>> configuration = Pix2StructTextConfig()
->>> # Initializing a Pix2StructTextModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
+>>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration
>>> model = Pix2StructTextModel(configuration)
>>> # Accessing the model configuration
@@ -118,6 +117,7 @@ class Pix2StructTextConfig(PretrainedConfig):
use_cache=False,
pad_token_id=0,
eos_token_id=1,
+tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -143,6 +143,7 @@ class Pix2StructTextConfig(PretrainedConfig):
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
+tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@@ -168,14 +169,13 @@ class Pix2StructTextConfig(PretrainedConfig):
class Pix2StructVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to
-instantiate a PIX2STRUCT vision model according to the specified arguments, defining the model architecture.
+instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture.
Instantiating a configuration defaults will yield a similar configuration to that of the Pix2Struct-base
-[Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base) architecture.
+[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
@@ -223,10 +223,10 @@ class Pix2StructVisionConfig(PretrainedConfig):
```python
>>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
->>> # Initializing a Pix2StructVisionConfig with Salesforce/pix2struct-vqa-base style configuration
+>>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration
>>> configuration = Pix2StructVisionConfig()
->>> # Initializing a Pix2StructVisionModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
+>>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration
>>> model = Pix2StructVisionModel(configuration)
>>> # Accessing the model configuration
@@ -301,11 +301,11 @@ class Pix2StructVisionConfig(PretrainedConfig):
class Pix2StructConfig(PretrainedConfig):
r"""
-[`Pix2StructConfig`] is the configuration class to store the configuration of a [`Pix2StructModel`]. It is used to
-instantiate a PIX2STRUCT model according to the specified arguments, defining the text model and vision model
-configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
-PIX2STRUCT-base [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base)
-architecture.
+[`Pix2StructConfig`] is the configuration class to store the configuration of a
+[`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified
+arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will
+yield a similar configuration to that of the Pix2Struct-base
+[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -327,20 +327,20 @@ class Pix2StructConfig(PretrainedConfig):
Example:
```python
->>> from transformers import Pix2StructConfig, Pix2StructModel
+>>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration
->>> # Initializing a Pix2StructConfig with Salesforce/pix2struct-vqa-base style configuration
+>>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration
>>> configuration = Pix2StructConfig()
->>> # Initializing a Pix2StructPModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
->>> model = Pix2StructModel(configuration)
+>>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration
+>>> model = Pix2StructForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig
->>> # Initializing a PIX2STRUCTText and PIX2STRUCTVision configuration
+>>> # Initializing a Pix2Struct text and Pix2Struct vision configuration
>>> config_text = Pix2StructTextConfig()
>>> config_vision = Pix2StructVisionConfig()

src/transformers/models/pix2struct/modeling_pix2struct.py

@@ -1369,6 +1369,12 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
@@ -1626,12 +1632,25 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
        self.post_init()

    def get_input_embeddings(self):
-        return self.shared
+        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
-        self.shared = new_embeddings
+        self.decoder.set_input_embeddings(new_embeddings)

+    def get_output_embeddings(self) -> nn.Module:
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.decoder.set_output_embeddings(new_embeddings)
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
+        model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)
+
+        # update vocab size
+        self.config.text_config.vocab_size = new_num_tokens
+
+        return model_embeds

    def get_decoder(self):
        return self.decoder
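
With these accessors in place, `Pix2StructForConditionalGeneration` no longer owns a separate `shared` embedding; everything is routed to the decoder, and resizing keeps the nested text config in sync. A quick sketch of what the delegation implies, using a randomly initialized model built from the default config as in the docstring examples above (the `+ 8` is an arbitrary illustrative size):

```python
from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration

config = Pix2StructConfig()
model = Pix2StructForConditionalGeneration(config)

# Input and output embeddings are the decoder's own modules, not a shared matrix.
assert model.get_input_embeddings() is model.decoder.get_input_embeddings()
assert model.get_output_embeddings() is model.decoder.get_output_embeddings()

# Resizing goes through the decoder and updates the nested text config.
old_size = model.config.text_config.vocab_size
new_embeddings = model.resize_token_embeddings(old_size + 8)
assert new_embeddings.weight.shape[0] == old_size + 8
assert model.config.text_config.vocab_size == old_size + 8
```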

tests/models/pix2struct/test_modeling_pix2struct.py

@@ -14,7 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch Pix2Struct model. """
+import copy
import inspect
import os
import tempfile
@@ -396,7 +396,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
    test_attention_outputs = False
    test_torchscript = False
@@ -526,6 +526,105 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
+
+    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
+    def test_resize_tokens_embeddings(self):
+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_resize_embeddings:
+            return
+
+        for model_class in self.all_model_classes:
+            config = copy.deepcopy(original_config)
+            model = model_class(config)
+            model.to(torch_device)
+
+            if self.model_tester.is_training is False:
+                model.eval()
+
+            model_vocab_size = config.text_config.vocab_size
+            # Retrieve the embeddings and clone theme
+            model_embed = model.resize_token_embeddings(model_vocab_size)
+            cloned_embeddings = model_embed.weight.clone()
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            # Decoder input ids should be clamped to the maximum size of the vocabulary
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
+            models_equal = True
+            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
+                if p1.data.ne(p2.data).sum() > 0:
+                    models_equal = False
+
+            self.assertTrue(models_equal)
+
+    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
+    def test_resize_embeddings_untied(self):
+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_resize_embeddings:
+            return
+
+        original_config.tie_word_embeddings = False
+
+        # if model cannot untied embeddings -> leave test
+        if original_config.tie_word_embeddings:
+            return
+
+        for model_class in self.all_model_classes:
+            config = copy.deepcopy(original_config)
+            model = model_class(config).to(torch_device)
+
+            # if no output embeddings -> leave test
+            if model.get_output_embeddings() is None:
+                continue
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_vocab_size = config.text_config.vocab_size
+            model.resize_token_embeddings(model_vocab_size + 10)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model.resize_token_embeddings(model_vocab_size - 15)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            # Decoder input ids should be clamped to the maximum size of the vocabulary
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+    @unittest.skip(reason="Pix2Struct doesn't use tied weights")
+    def test_tied_model_weights_key_ignore(self):
+        pass
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return