mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add support for inheritance from class with different suffix in modular (#34077)
* add support for different suffix in modular * add dummy example, pull new changes for modular * nide lines order change
This commit is contained in:
parent
d314ce70bf
commit
65442718c4
546
examples/modular-transformers/modeling_new_task_model.py
Normal file
546
examples/modular-transformers/modeling_new_task_model.py
Normal file
@ -0,0 +1,546 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_new_task_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_new_task_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_new_task_model import NewTaskModelConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "NewTaskModelConfig"
|
||||
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# But NewTaskModel has no causal mask on prefix
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_training: bool = False,
|
||||
token_type_ids: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
is_training (`bool`):
|
||||
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
if is_training:
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for NewTaskModelcausal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class NewTaskModelMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: NewTaskModelConfig):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear(image_features)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
NEW_TASK_MODEL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`NewTaskModelConfig`] or [`NewTaskModelVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
NEW_TASK_MODEL_START_DOCSTRING,
|
||||
)
|
||||
class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
config_class = NewTaskModelConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NewTaskModelMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of NewTaskModelisn't meant for training from scratch - only
|
||||
# inference and fine-tuning
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@property
|
||||
def _supports_sdpa(self):
|
||||
"""
|
||||
Retrieve language_model's attribute to check whether the model supports
|
||||
SDPA or not.
|
||||
"""
|
||||
return self.language_model._supports_sdpa
|
||||
|
||||
|
||||
NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses
|
||||
[`SiglipImageProcessor`] for processing images).
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The NEW_TASK_MODEL model which consists of a vision backbone and a language model.""",
|
||||
NEW_TASK_MODEL_START_DOCSTRING,
|
||||
)
|
||||
class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self._attn_implementation = config._attn_implementation
|
||||
|
||||
language_model = AutoModelForCausalLM.from_config(
|
||||
config=config.text_config, attn_implementation=self._attn_implementation
|
||||
)
|
||||
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
self.language_model = language_model
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def _update_causal_mask(
|
||||
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
||||
):
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = inputs_embeds.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
if is_training:
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
num_logits_to_keep (`int`, *optional*):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
|
||||
|
||||
>>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/NewTaskModel-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```
|
||||
Returns:
|
||||
"""
|
||||
vlm_outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
token_type_ids=token_type_ids,
|
||||
cache_position=cache_position,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
last_hidden_states = vlm_outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
dtype = self.get_output_embeddings().weight.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_length(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
model_inputs["token_type_ids"] = token_type_ids
|
||||
|
||||
# position_ids in NewTaskModel are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
def resize_token_embeddings(
|
||||
self,
|
||||
new_num_tokens: Optional[int] = None,
|
||||
pad_to_multiple_of=None,
|
||||
) -> nn.Embedding:
|
||||
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
||||
|
||||
# Update vocab size
|
||||
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
||||
self.config.vocab_size = model_embeds.num_embeddings
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
|
||||
return model_embeds
|
84
examples/modular-transformers/modular_new_task_model.py
Normal file
84
examples/modular-transformers/modular_new_task_model.py
Normal file
@ -0,0 +1,84 @@
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from ...cache_utils import Cache
|
||||
|
||||
|
||||
class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
"""
|
||||
vlm_outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
token_type_ids=token_type_ids,
|
||||
cache_position=cache_position,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
last_hidden_states = vlm_outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
def resize_token_embeddings(
|
||||
self,
|
||||
new_num_tokens: Optional[int] = None,
|
||||
pad_to_multiple_of=None,
|
||||
) -> nn.Embedding:
|
||||
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
||||
|
||||
# Update vocab size
|
||||
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
||||
self.config.vocab_size = model_embeds.num_embeddings
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
|
||||
return model_embeds
|
@ -204,7 +204,15 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
- LLaMa -> MyNewModel abd MyNewModel -> Llama
|
||||
"""
|
||||
|
||||
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
|
||||
def __init__(
|
||||
self,
|
||||
old_name,
|
||||
new_name,
|
||||
given_old_name=None,
|
||||
given_new_name=None,
|
||||
old_class_name: str = None,
|
||||
new_class_name: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
@ -220,6 +228,18 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
}
|
||||
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
|
||||
self.patterns[given_old_name] = given_new_name
|
||||
if self.old_name in CONFIG_MAPPING_NAMES:
|
||||
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
|
||||
if self.default_old_name.isupper():
|
||||
self.default_old_name = self.default_old_name.capitalize()
|
||||
if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns:
|
||||
# In last recourse, when the suffix of the new class is not the same as the old class,
|
||||
# and if the old and new classes start with the default name, we keep the default class name
|
||||
# and replace the old suffix with the new one.
|
||||
# Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`
|
||||
# where a model extends another model, but is used for a different task.
|
||||
if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name):
|
||||
self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :]
|
||||
|
||||
def preserve_case_replace(self, text):
|
||||
# Create a regex pattern to match all variations
|
||||
@ -235,7 +255,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
|
||||
def convert_to_camelcase(self, text):
|
||||
# Regex pattern to match consecutive uppercase letters and lowercase the first set
|
||||
result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
|
||||
result = re.sub(
|
||||
rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1
|
||||
)
|
||||
return result
|
||||
|
||||
@m.leave(m.Name() | m.SimpleString() | m.Comment())
|
||||
@ -249,9 +271,24 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
|
||||
|
||||
|
||||
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
|
||||
def find_classes_in_file(
|
||||
module: cst.Module,
|
||||
old_id="llama",
|
||||
new_id="gemma",
|
||||
given_old_name=None,
|
||||
given_new_name=None,
|
||||
old_class_name=None,
|
||||
new_class_name=None,
|
||||
):
|
||||
"""Helper function to rename and then parse a source file using the ClassFinder"""
|
||||
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
|
||||
transformer = ReplaceNameTransformer(
|
||||
old_id,
|
||||
new_id,
|
||||
given_old_name=given_old_name,
|
||||
given_new_name=given_new_name,
|
||||
old_class_name=old_class_name,
|
||||
new_class_name=new_class_name,
|
||||
)
|
||||
new_module = module.visit(transformer)
|
||||
|
||||
wrapper = MetadataWrapper(new_module)
|
||||
@ -868,7 +905,7 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if list_dependencies == []:
|
||||
if len(list_dependencies) == 0:
|
||||
# so, maybe standard renaming did not work (the class name is different)
|
||||
# we try with another renaming pattern
|
||||
potential_given_name = get_new_part(class_name, super_class)
|
||||
@ -884,6 +921,30 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if len(list_dependencies) == 0:
|
||||
# last recourse, if the suffix of the new class is different from the one of the super class
|
||||
# e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection
|
||||
# we try with another renaming pattern
|
||||
class_finder = find_classes_in_file(
|
||||
self.transformers_imports[super_file_name],
|
||||
model_name,
|
||||
self.model_name,
|
||||
self.given_old_name,
|
||||
self.given_new_name,
|
||||
super_class,
|
||||
class_name,
|
||||
)
|
||||
visited_module[super_file_name] = class_finder
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if len(list_dependencies) == 0:
|
||||
raise ValueError(
|
||||
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
|
||||
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
|
||||
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
|
||||
)
|
||||
|
||||
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
|
||||
start_insert_idx = self.global_scope_index
|
||||
@ -917,12 +978,6 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
|
||||
if len(list_dependencies) > 0:
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
|
||||
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
|
||||
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
|
||||
)
|
||||
|
||||
# Now, if a class was defined without parents, we look for the name
|
||||
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
|
||||
|
Loading…
Reference in New Issue
Block a user