Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
Large modular logic refactoring (#34487)
* rework converter
* Update modular_model_converter.py
* Update modular_model_converter.py
* Update modular_model_converter.py
* Update modular_model_converter.py
* cleaning
* cleaning
* finalize imports
* imports
* Update modular_model_converter.py
* Better renaming to avoid visiting same file multiple times
* start converting files
* style
* address most comments
* style
* remove unused stuff in get_needed_imports
* style
* move class dependency functions outside class
* Move main functions outside class
* style
* Update modular_model_converter.py
* rename func
* add augmented dependencies
* Update modular_model_converter.py
* Add types_to_file_type + tweak annotation handling
* Allow assignment dependency mapping + fix regex
* style + update modular examples
* fix modular_roberta example (wrong redefinition of __init__)
* slightly correct order in which dependencies will appear
* style
* review comments
* Performance + better handling of dependencies when they are imported
* style
* Add advanced new classes capabilities
* style
* add forgotten check
* Update modeling_llava_next_video.py
* Add priority list ordering in check_conversion as well
* Update check_modular_conversion.py
* Update configuration_gemma.py
parent 86701f2b6f
commit e2ac16b28a
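For context, this commit reworks the modular model converter (`modular_model_converter.py`), the tool that expands the small `modular_*.py` definition files under `examples/modular-transformers/` into the standalone, auto-generated `configuration_*` / `modeling_*` files whose regenerated output makes up the diff below. As a hedged illustration only (the import path and class body here are assumptions, not taken from this commit), a modular definition file is roughly of this shape, and the converter inlines everything the class inherits:

```python
# Hypothetical sketch of a modular definition file such as
# examples/modular-transformers/modular_my_new_model2.py (names assumed for illustration).
# The converter would expand the inherited Gemma configuration into a full,
# self-contained configuration file like the ones regenerated in this diff.
from transformers.models.gemma.configuration_gemma import GemmaConfig


class MyNewModel2Config(GemmaConfig):
    # Nothing overridden here: the generated file still receives the complete
    # docstring, __init__ signature, and defaults copied over from GemmaConfig.
    pass
```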
@@ -1,9 +1,9 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_modular_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_my_new_model.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_xxx.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# modular_my_new_model.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
@@ -158,6 +158,13 @@ class MyNewModelConfig(PretrainedConfig):
 new_param=0,
 **kwargs,
 ):
+super().__init__(
+pad_token_id=pad_token_id,
+bos_token_id=bos_token_id,
+eos_token_id=eos_token_id,
+tie_word_embeddings=tie_word_embeddings,
+**kwargs,
+)
 self.vocab_size = vocab_size
 self.max_position_embeddings = max_position_embeddings
 self.hidden_size = hidden_size
@@ -187,11 +194,3 @@ class MyNewModelConfig(PretrainedConfig):
 self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 rope_config_validation(self)
 self.new_param = new_param
-
-super().__init__(
-pad_token_id=pad_token_id,
-bos_token_id=bos_token_id,
-eos_token_id=eos_token_id,
-tie_word_embeddings=tie_word_embeddings,
-**kwargs,
-)
@@ -1,9 +1,9 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_modular_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_xxx.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# modular_my_new_model2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
@@ -11,106 +11,6 @@ from ...modeling_rope_utils import rope_config_validation

 class MyNewModel2Config(PretrainedConfig):
 r"""
-This is the configuration class to store the configuration of a [`MyNewModel2Model`]. It is used to instantiate an MyNewModel2
-model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-defaults will yield a similar configuration to that of the MyNewModel2-7B.
-
-Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-documentation from [`PretrainedConfig`] for more information.
-
-
-Args:
-vocab_size (`int`, *optional*, defaults to 32000):
-Vocabulary size of the MyNewModel2 model. Defines the number of different tokens that can be represented by the
-`inputs_ids` passed when calling [`MyNewModel2Model`]
-hidden_size (`int`, *optional*, defaults to 4096):
-Dimension of the hidden representations.
-intermediate_size (`int`, *optional*, defaults to 11008):
-Dimension of the MLP representations.
-num_hidden_layers (`int`, *optional*, defaults to 32):
-Number of hidden layers in the Transformer decoder.
-num_attention_heads (`int`, *optional*, defaults to 32):
-Number of attention heads for each attention layer in the Transformer decoder.
-num_key_value_heads (`int`, *optional*):
-This is the number of key_value heads that should be used to implement Grouped Query Attention. If
-`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
-`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
-converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
-by meanpooling all the original heads within that group. For more details checkout [this
-paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
-`num_attention_heads`.
-hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-The non-linear activation function (function or string) in the decoder.
-max_position_embeddings (`int`, *optional*, defaults to 2048):
-The maximum sequence length that this model might ever be used with. MyNewModel2 1 supports up to 2048 tokens,
-MyNewModel2 2 up to 4096, CodeMyNewModel2 up to 16384.
-initializer_range (`float`, *optional*, defaults to 0.02):
-The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-The epsilon used by the rms normalization layers.
-use_cache (`bool`, *optional*, defaults to `True`):
-Whether or not the model should return the last key/values attentions (not used by all models). Only
-relevant if `config.is_decoder=True`.
-pad_token_id (`int`, *optional*):
-Padding token id.
-bos_token_id (`int`, *optional*, defaults to 1):
-Beginning of stream token id.
-eos_token_id (`int`, *optional*, defaults to 2):
-End of stream token id.
-pretraining_tp (`int`, *optional*, defaults to 1):
-Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
-understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
-results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
-tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-Whether to tie weight embeddings
-rope_theta (`float`, *optional*, defaults to 10000.0):
-The base period of the RoPE embeddings.
-rope_scaling (`Dict`, *optional*):
-Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
-and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
-accordingly.
-Expected contents:
-`rope_type` (`str`):
-The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
-'my_new_model23'], with 'default' being the original RoPE implementation.
-`factor` (`float`, *optional*):
-Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
-most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-original maximum pre-trained length.
-`original_max_position_embeddings` (`int`, *optional*):
-Used with 'dynamic', 'longrope' and 'my_new_model23'. The original max position embeddings used during
-pretraining.
-`attention_factor` (`float`, *optional*):
-Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
-computation. If unspecified, it defaults to value recommended by the implementation, using the
-`factor` field to infer the suggested value.
-`beta_fast` (`float`, *optional*):
-Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
-ramp function. If unspecified, it defaults to 32.
-`beta_slow` (`float`, *optional*):
-Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
-ramp function. If unspecified, it defaults to 1.
-`short_factor` (`List[float]`, *optional*):
-Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-size divided by the number of attention heads divided by 2
-`long_factor` (`List[float]`, *optional*):
-Only used with 'longrope'. The scaling factor to be applied to long contexts (<
-`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-size divided by the number of attention heads divided by 2
-`low_freq_factor` (`float`, *optional*):
-Only used with 'my_new_model23'. Scaling factor applied to low frequency components of the RoPE
-`high_freq_factor` (`float`, *optional*):
-Only used with 'my_new_model23'. Scaling factor applied to high frequency components of the RoPE
-attention_bias (`bool`, *optional*, defaults to `False`):
-Whether to use a bias in the query, key, value and output projection layers during self-attention.
-attention_dropout (`float`, *optional*, defaults to 0.0):
-The dropout ratio for the attention probabilities.
-mlp_bias (`bool`, *optional*, defaults to `False`):
-Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
-head_dim (`int`, *optional*):
-The attention head dimension. If None, it will default to hidden_size // num_heads
 This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
 defaults will yield a similar configuration to that of the Gemma-7B.
@@ -121,7 +21,6 @@ class MyNewModel2Config(PretrainedConfig):
 vocab_size (`int`, *optional*, defaults to 256000):
 Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
 `inputs_ids` passed when calling [`GemmaModel`]
-
 ```python
 >>> from transformers import GemmaModel, GemmaConfig
 >>> # Initializing a Gemma gemma-7b style configuration
@@ -1,9 +1,9 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_modular_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_new_model.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_xxx.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# modular_new_model.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Example where we only want to overwrite the defaults of an init

 from ...configuration_utils import PretrainedConfig
@@ -104,6 +104,13 @@ class NewModelConfig(PretrainedConfig):
 attention_dropout=0.0,
 **kwargs,
 ):
+super().__init__(
+pad_token_id=pad_token_id,
+bos_token_id=bos_token_id,
+eos_token_id=eos_token_id,
+tie_word_embeddings=tie_word_embeddings,
+**kwargs,
+)
 self.vocab_size = vocab_size
 self.max_position_embeddings = max_position_embeddings
 self.hidden_size = hidden_size
@@ -121,14 +128,6 @@ class NewModelConfig(PretrainedConfig):
 self.attention_bias = attention_bias
 self.attention_dropout = attention_dropout

-super().__init__(
-pad_token_id=pad_token_id,
-bos_token_id=bos_token_id,
-eos_token_id=eos_token_id,
-tie_word_embeddings=tie_word_embeddings,
-**kwargs,
-)
-
 @property
 def num_heads(self):
 return self.num_attention_heads
@@ -1,26 +1,24 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_modular_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_dummy.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_xxx.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# modular_dummy.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
-from ...modeling_outputs import (
-BaseModelOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
 add_start_docstrings,
 add_start_docstrings_to_model_forward,
@@ -33,59 +31,6 @@ from .configuration_dummy import DummyConfig
 logger = logging.get_logger(__name__)


-def _prepare_4d_causal_attention_mask_with_cache_position(
-attention_mask: torch.Tensor,
-sequence_length: int,
-target_length: int,
-dtype: torch.dtype,
-device: torch.device,
-min_dtype: float,
-cache_position: torch.Tensor,
-batch_size: int,
-):
-"""
-Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-Args:
-attention_mask (`torch.Tensor`):
-A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-sequence_length (`int`):
-The sequence length being processed.
-target_length (`int`):
-The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-dtype (`torch.dtype`):
-The dtype to use for the 4D attention mask.
-device (`torch.device`):
-The device to plcae the 4D attention mask on.
-min_dtype (`float`):
-The minimum value representable with the dtype `dtype`.
-cache_position (`torch.Tensor`):
-Indices depicting the position of the input sequence tokens in the sequence.
-batch_size (`torch.Tensor`):
-Batch size.
-"""
-if attention_mask is not None and attention_mask.dim() == 4:
-# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-causal_mask = attention_mask
-else:
-causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-if sequence_length != 1:
-causal_mask = torch.triu(causal_mask, diagonal=1)
-causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-if attention_mask is not None:
-causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
-mask_length = attention_mask.shape[-1]
-padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-padding_mask = padding_mask == 0
-causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-padding_mask, min_dtype
-)
-
-return causal_mask
-
-
 class DummyRMSNorm(nn.Module):
 def __init__(self, hidden_size, eps=1e-6):
 """
@@ -193,40 +138,6 @@ class DummyRotaryEmbedding(nn.Module):
 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-def rotate_half(x):
-"""Rotates half the hidden dims of the input."""
-x1 = x[..., : x.shape[-1] // 4]
-x2 = x[..., x.shape[-1] // 4 :]
-return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-"""Applies Rotary Position Embedding to the query and key tensors.
-
-Args:
-q (`torch.Tensor`): The query tensor.
-k (`torch.Tensor`): The key tensor.
-cos (`torch.Tensor`): The cosine part of the rotary embedding.
-sin (`torch.Tensor`): The sine part of the rotary embedding.
-position_ids (`torch.Tensor`, *optional*):
-Deprecated and unused.
-unsqueeze_dim (`int`, *optional*, defaults to 1):
-The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-Returns:
-`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-"""
-cos = cos.unsqueeze(unsqueeze_dim)
-sin = sin.unsqueeze(unsqueeze_dim)
-q_embed = (q * cos) + (rotate_half(q) * sin)
-k_embed = (k * cos) + (rotate_half(k) * sin)
-return q_embed, k_embed
-
-
 class DummyMLP(nn.Module):
 def __init__(self, config):
 super().__init__()
@@ -261,6 +172,40 @@ class DummyMLP(nn.Module):
 return down_proj


+def rotate_half(x):
+"""Rotates half the hidden dims of the input."""
+x1 = x[..., : x.shape[-1] // 4]
+x2 = x[..., x.shape[-1] // 4 :]
+return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+"""Applies Rotary Position Embedding to the query and key tensors.
+
+Args:
+q (`torch.Tensor`): The query tensor.
+k (`torch.Tensor`): The key tensor.
+cos (`torch.Tensor`): The cosine part of the rotary embedding.
+sin (`torch.Tensor`): The sine part of the rotary embedding.
+position_ids (`torch.Tensor`, *optional*):
+Deprecated and unused.
+unsqueeze_dim (`int`, *optional*, defaults to 1):
+The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+Returns:
+`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+"""
+cos = cos.unsqueeze(unsqueeze_dim)
+sin = sin.unsqueeze(unsqueeze_dim)
+q_embed = (q * cos) + (rotate_half(q) * sin)
+k_embed = (k * cos) + (rotate_half(k) * sin)
+return q_embed, k_embed
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 """
 This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -423,6 +368,7 @@ class DummyFlashAttention2(DummyAttention):
 use_cache: bool = False,
 cache_position: Optional[torch.LongTensor] = None,
 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+**kwargs: Unpack[FlashAttentionKwargs],
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 if isinstance(past_key_value, StaticCache):
 raise ValueError(
@@ -507,6 +453,7 @@ class DummyFlashAttention2(DummyAttention):
 sliding_window=getattr(self, "sliding_window", None),
 use_top_left_mask=self._flash_attn_uses_top_left_mask,
 is_causal=self.is_causal,
+**kwargs,
 )

 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -871,6 +818,7 @@ class DummyModel(DummyPreTrainedModel):
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
+**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
 ) -> Union[Tuple, BaseModelOutputWithPast]:
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
@@ -952,6 +900,7 @@ class DummyModel(DummyPreTrainedModel):
 use_cache=use_cache,
 cache_position=cache_position,
 position_embeddings=position_embeddings,
+**flash_attn_kwargs,
 )

 hidden_states = layer_outputs[0]
@@ -1011,10 +960,9 @@ class DummyModel(DummyPreTrainedModel):
 return None

 dtype, device = input_tensor.dtype, input_tensor.device
-min_dtype = torch.finfo(dtype).min
 sequence_length = input_tensor.shape[1]
 if using_static_cache:
-target_length = past_key_values.get_max_length()
+target_length = past_key_values.get_max_cache_shape()
 else:
 target_length = (
 attention_mask.shape[-1]
@@ -1023,13 +971,12 @@ class DummyModel(DummyPreTrainedModel):
 )

 # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
 attention_mask,
 sequence_length=sequence_length,
 target_length=target_length,
 dtype=dtype,
 device=device,
-min_dtype=min_dtype,
 cache_position=cache_position,
 batch_size=input_tensor.shape[0],
 )
@@ -1043,6 +990,63 @@ class DummyModel(DummyPreTrainedModel):
 # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
 # Details: https://github.com/pytorch/pytorch/issues/110213
+min_dtype = torch.finfo(dtype).min
 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

 return causal_mask
+
+@staticmethod
+def _prepare_4d_causal_attention_mask_with_cache_position(
+attention_mask: torch.Tensor,
+sequence_length: int,
+target_length: int,
+dtype: torch.dtype,
+device: torch.device,
+cache_position: torch.Tensor,
+batch_size: int,
+**kwargs,
+):
+"""
+Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+Args:
+attention_mask (`torch.Tensor`):
+A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+`(batch_size, 1, query_length, key_value_length)`.
+sequence_length (`int`):
+The sequence length being processed.
+target_length (`int`):
+The target length: when generating with static cache, the mask should be as long as the static cache,
+to account for the 0 padding, the part of the cache that is not filled yet.
+dtype (`torch.dtype`):
+The dtype to use for the 4D attention mask.
+device (`torch.device`):
+The device to plcae the 4D attention mask on.
+cache_position (`torch.Tensor`):
+Indices depicting the position of the input sequence tokens in the sequence.
+batch_size (`torch.Tensor`):
+Batch size.
+"""
+if attention_mask is not None and attention_mask.dim() == 4:
+# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+causal_mask = attention_mask
+else:
+min_dtype = torch.finfo(dtype).min
+causal_mask = torch.full(
+(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+)
+if sequence_length != 1:
+causal_mask = torch.triu(causal_mask, diagonal=1)
+causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+if attention_mask is not None:
+causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+mask_length = attention_mask.shape[-1]
+padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+padding_mask = padding_mask == 0
+causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+padding_mask, min_dtype
+)
+
+return causal_mask
@@ -1,27 +1,20 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_modular_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_dummy_bert.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_xxx.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# modular_dummy_bert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 import os
 from typing import List, Optional, Tuple, Union

 import torch
-import torch.utils.checkpoint
 from packaging import version
 from torch import nn

 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import (
-_prepare_4d_attention_mask_for_sdpa,
-_prepare_4d_causal_attention_mask_for_sdpa,
-)
-from ...modeling_outputs import (
-BaseModelOutputWithPastAndCrossAttentions,
-BaseModelOutputWithPoolingAndCrossAttentions,
-)
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -40,79 +33,6 @@ _CHECKPOINT_FOR_DOC = "google-dummy_bert/dummy_bert-base-uncased"
 _CONFIG_FOR_DOC = "DummyBertConfig"


-def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path):
-"""Load tf checkpoints in a pytorch model."""
-try:
-import re
-
-import numpy as np
-import tensorflow as tf
-except ImportError:
-logger.error(
-"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
-"https://www.tensorflow.org/install/ for installation instructions."
-)
-raise
-tf_path = os.path.abspath(tf_checkpoint_path)
-logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
-# Load weights from TF model
-init_vars = tf.train.list_variables(tf_path)
-names = []
-arrays = []
-for name, shape in init_vars:
-logger.info(f"Loading TF weight {name} with shape {shape}")
-array = tf.train.load_variable(tf_path, name)
-names.append(name)
-arrays.append(array)
-
-for name, array in zip(names, arrays):
-name = name.split("/")
-# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-# which are not required for using pretrained model
-if any(
-n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
-for n in name
-):
-logger.info(f"Skipping {'/'.join(name)}")
-continue
-pointer = model
-for m_name in name:
-if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
-scope_names = re.split(r"_(\d+)", m_name)
-else:
-scope_names = [m_name]
-if scope_names[0] == "kernel" or scope_names[0] == "gamma":
-pointer = getattr(pointer, "weight")
-elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
-pointer = getattr(pointer, "bias")
-elif scope_names[0] == "output_weights":
-pointer = getattr(pointer, "weight")
-elif scope_names[0] == "squad":
-pointer = getattr(pointer, "classifier")
-else:
-try:
-pointer = getattr(pointer, scope_names[0])
-except AttributeError:
-logger.info(f"Skipping {'/'.join(name)}")
-continue
-if len(scope_names) >= 2:
-num = int(scope_names[1])
-pointer = pointer[num]
-if m_name[-11:] == "_embeddings":
-pointer = getattr(pointer, "weight")
-elif m_name == "kernel":
-array = np.transpose(array)
-try:
-if pointer.shape != array.shape:
-raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
-except ValueError as e:
-e.args += (pointer.shape, array.shape)
-raise
-logger.info(f"Initialize PyTorch weight {name}")
-pointer.data = torch.from_numpy(array)
-return model
-
-
 class DummyBertEmbeddings(nn.Module):
 """Construct the embeddings from word, position and token_type embeddings."""

@@ -706,6 +626,79 @@ class DummyBertPooler(nn.Module):
 return pooled_output


+def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path):
+"""Load tf checkpoints in a pytorch model."""
+try:
+import re
+
+import numpy as np
+import tensorflow as tf
+except ImportError:
+logger.error(
+"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+"https://www.tensorflow.org/install/ for installation instructions."
+)
+raise
+tf_path = os.path.abspath(tf_checkpoint_path)
+logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+# Load weights from TF model
+init_vars = tf.train.list_variables(tf_path)
+names = []
+arrays = []
+for name, shape in init_vars:
+logger.info(f"Loading TF weight {name} with shape {shape}")
+array = tf.train.load_variable(tf_path, name)
+names.append(name)
+arrays.append(array)
+
+for name, array in zip(names, arrays):
+name = name.split("/")
+# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+# which are not required for using pretrained model
+if any(
+n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+for n in name
+):
+logger.info(f"Skipping {'/'.join(name)}")
+continue
+pointer = model
+for m_name in name:
+if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+scope_names = re.split(r"_(\d+)", m_name)
+else:
+scope_names = [m_name]
+if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+pointer = getattr(pointer, "weight")
+elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+pointer = getattr(pointer, "bias")
+elif scope_names[0] == "output_weights":
+pointer = getattr(pointer, "weight")
+elif scope_names[0] == "squad":
+pointer = getattr(pointer, "classifier")
+else:
+try:
+pointer = getattr(pointer, scope_names[0])
+except AttributeError:
+logger.info(f"Skipping {'/'.join(name)}")
+continue
+if len(scope_names) >= 2:
+num = int(scope_names[1])
+pointer = pointer[num]
+if m_name[-11:] == "_embeddings":
+pointer = getattr(pointer, "weight")
+elif m_name == "kernel":
+array = np.transpose(array)
+try:
+if pointer.shape != array.shape:
+raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+except ValueError as e:
+e.args += (pointer.shape, array.shape)
+raise
+logger.info(f"Initialize PyTorch weight {name}")
+pointer.data = torch.from_numpy(array)
+return model
+
+
 class DummyBertPreTrainedModel(PreTrainedModel):
 """
 An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -871,26 +864,6 @@ class DummyBertModel(DummyBertPreTrainedModel):
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
-r"""
-encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-the model is configured as a decoder.
-encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
-Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
-- 1 for tokens that are **not masked**,
-- 0 for tokens that are **masked**.
-past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
-Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
-If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-`decoder_input_ids` of shape `(batch_size, sequence_length)`.
-use_cache (`bool`, *optional*):
-If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-`past_key_values`).
-"""
 r"""
 encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1027,7 +1000,6 @@ class DummyBertModel(DummyBertPreTrainedModel):

 if not return_dict:
 return (sequence_output, pooled_output) + encoder_outputs[1:]
-return super().forward(input_ids)

 return BaseModelOutputWithPoolingAndCrossAttentions(
 last_hidden_state=sequence_output,
@ -1,25 +1,20 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_my_new_model2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
BaseModelOutputWithPast,
|
|
||||||
SequenceClassifierOutputWithPast,
|
|
||||||
)
|
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@ -30,6 +25,9 @@ from ...utils import (
|
|||||||
from .configuration_my_new_model2 import MyNewModel2Config
|
from .configuration_my_new_model2 import MyNewModel2Config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MyNewModel2RMSNorm(nn.Module):
|
class MyNewModel2RMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -50,9 +48,6 @@ class MyNewModel2RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MyNewModel2RotaryEmbedding(nn.Module):
|
class MyNewModel2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -448,59 +443,6 @@ class MyNewModel2FlashAttention2(MyNewModel2Attention):
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
sequence_length: int,
|
|
||||||
target_length: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
min_dtype: float,
|
|
||||||
cache_position: torch.Tensor,
|
|
||||||
batch_size: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
|
||||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attention_mask (`torch.Tensor`):
|
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
|
||||||
The sequence length being processed.
|
|
||||||
target_length (`int`):
|
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
|
||||||
The dtype to use for the 4D attention mask.
|
|
||||||
device (`torch.device`):
|
|
||||||
The device to plcae the 4D attention mask on.
|
|
||||||
min_dtype (`float`):
|
|
||||||
The minimum value representable with the dtype `dtype`.
|
|
||||||
cache_position (`torch.Tensor`):
|
|
||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
|
||||||
batch_size (`torch.Tensor`):
|
|
||||||
Batch size.
|
|
||||||
"""
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
|
||||||
causal_mask = attention_mask
|
|
||||||
else:
|
|
||||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
|
||||||
if sequence_length != 1:
|
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
|
||||||
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
-
-    return causal_mask
-
-
 MY_NEW_MODEL2_ATTENTION_CLASSES = {
     "eager": MyNewModel2Attention,
     "flash_attention_2": MyNewModel2FlashAttention2,
@@ -893,10 +835,9 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
                 return None
 
         dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -905,13 +846,12 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
             )
 
         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
             device=device,
-            min_dtype=min_dtype,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -925,10 +865,67 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
 
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to plcae the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
 
 @add_start_docstrings(
     """
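The hunk above replaces the module-level mask helper with a `_prepare_4d_causal_attention_mask_with_cache_position` static method (and moves the `min_dtype` computation inside it). For readers who want to poke at the mask shapes outside the generated class, here is a standalone restatement of that helper as a sketch; the `build_causal_mask` name and the toy shapes are illustrative, not part of the diff.

```python
import torch

def build_causal_mask(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    # Standalone restatement of the static helper added above, for experimentation only.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)            # mask out future positions
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:                                    # fold in the 2D padding mask
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask

mask = build_causal_mask(torch.ones(2, 6), 6, 8, torch.float32, torch.device("cpu"), torch.arange(6), 2)
assert mask.shape == (2, 1, 6, 8)   # (batch, 1, query_length, key_value_length)
```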
@@ -1019,27 +1016,8 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
 
         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
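In the sequence-classification hunk above, the inlined `problem_type` branching is dropped in favour of the shared `self.loss_function` hook. A rough standalone equivalent of the removed logic is sketched below, useful for sanity-checking that the two paths agree; the helper name is made up and the exact signature of the shared hook is not shown by this diff.

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def sequence_classification_loss(pooled_logits, labels, num_labels, problem_type=None):
    # Reproduces the branching removed above; the generated class now delegates this to `self.loss_function`.
    if problem_type is None:
        if num_labels == 1:
            problem_type = "regression"
        elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"
    if problem_type == "regression":
        loss_fct = MSELoss()
        return loss_fct(pooled_logits.squeeze(), labels.squeeze()) if num_labels == 1 else loss_fct(pooled_logits, labels)
    if problem_type == "single_label_classification":
        return CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(pooled_logits, labels)

loss = sequence_classification_loss(torch.randn(4, 3), torch.tensor([0, 2, 1, 1]), num_labels=3)
```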
@@ -8,7 +8,6 @@ from dataclasses import dataclass
 from typing import ClassVar, List, Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...cache_utils import Cache, StaticCache
@@ -18,92 +17,15 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
-    logging,
     replace_return_docstrings,
 )
+from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_new_task_model import NewTaskModelConfig
 
 
-if is_flash_attn_2_available():
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-from ..auto import AutoModel, AutoModelForCausalLM
-
-
-logger = logging.get_logger(__name__)
-
 _CONFIG_FOR_DOC = "NewTaskModelConfig"
 
 
-# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-# But NewTaskModel has no causal mask on prefix
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    min_dtype: float,
-    cache_position: torch.Tensor,
-    batch_size: int,
-    is_training: bool = False,
-    token_type_ids: torch.Tensor = None,
-):
-    """
-    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-    Args:
-        attention_mask (`torch.Tensor`):
-            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-        sequence_length (`int`):
-            The sequence length being processed.
-        target_length (`int`):
-            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-        dtype (`torch.dtype`):
-            The dtype to use for the 4D attention mask.
-        device (`torch.device`):
-            The device to plcae the 4D attention mask on.
-        min_dtype (`float`):
-            The minimum value representable with the dtype `dtype`.
-        cache_position (`torch.Tensor`):
-            Indices depicting the position of the input sequence tokens in the sequence.
-        batch_size (`torch.Tensor`):
-            Batch size.
-        is_training (`bool`):
-            Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
-    """
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
-        if sequence_length != 1:
-            if is_training:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            else:
-                causal_mask[:, :sequence_length] = 0.0
-
-        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
-            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-            if is_training:
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
-                )
-    return causal_mask
-
-
 @dataclass
 class NewTaskModelCausalLMOutputWithPast(ModelOutput):
     """
@@ -182,12 +104,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["NewTaskModelMultiModalProjector"]
     _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = False
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
-    _supports_sdpa = True
     _supports_cache_class = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         # important: this ported version of NewTaskModelisn't meant for training from scratch - only
@@ -210,14 +132,6 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
 
-    @property
-    def _supports_sdpa(self):
-        """
-        Retrieve language_model's attribute to check whether the model supports
-        SDPA or not.
-        """
-        return self.language_model._supports_sdpa
-
 
 NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
     Args:
@@ -301,11 +215,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
-        self._attn_implementation = config._attn_implementation
-
-        language_model = AutoModelForCausalLM.from_config(
-            config=config.text_config, attn_implementation=self._attn_implementation
-        )
+        language_model = AutoModelForCausalLM.from_config(config=config.text_config)
 
         if language_model._tied_weights_keys is not None:
             self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
@@ -344,6 +255,11 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
     def _update_causal_mask(
         self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
     ):
+        if self.config.text_config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
         using_static_cache = isinstance(past_key_values, StaticCache)
         dtype = inputs_embeds.dtype
         min_dtype = torch.finfo(dtype).min
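The early return added to `_update_causal_mask` above reflects that the flash-attention-2 path consumes the 2D padding mask directly rather than a materialized 4D causal mask. A minimal, self-contained sketch of just that branch (the example mask is made up):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # one right-padded sequence

# Mirror of the branch added above: keep the 2D mask only when it actually carries padding information.
if attention_mask is not None and 0.0 in attention_mask:
    mask_for_kernel = attention_mask
else:
    mask_for_kernel = None  # fully dense batch: the kernel needs no mask at all
```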
@@ -388,6 +304,22 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         )
         return causal_mask
 
+    def get_image_features(self, pixel_values: torch.FloatTensor):
+        """
+        Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
+               The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+        """
+        image_outputs = self.vision_tower(pixel_values)
+        selected_image_feature = image_outputs.last_hidden_state
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = image_features / (self.config.hidden_size**0.5)
+        return image_features
+
     @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
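In `get_image_features` added above, the projector output is scaled by `hidden_size ** -0.5`. A toy sketch of just the projection-and-scaling step, with made-up dimensions standing in for the vision tower and the multi-modal projector:

```python
import torch

hidden_size = 64
projector = torch.nn.Linear(32, hidden_size)     # stand-in for the multi-modal projector
last_hidden_state = torch.randn(1, 16, 32)       # stand-in for vision_tower(pixel_values).last_hidden_state
image_features = projector(last_hidden_state) / (hidden_size**0.5)
```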
@@ -426,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         ```python
         >>> from PIL import Image
         >>> import requests
-        >>> from transformers import AutoProcessor, NewTaskModelForNewTask
+        >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
 
-        >>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
+        >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf")
         >>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")
 
         >>> prompt = "answer en Where is the cow standing?"
@@ -484,6 +416,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         num_logits_to_keep=None,
         **kwargs,
     ):
+        # Overwritten -- custom `position_ids` and `pixel_values` handling
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -493,33 +426,10 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
             cache_position=cache_position,
             use_cache=use_cache,
             num_logits_to_keep=num_logits_to_keep,
+            token_type_ids=token_type_ids,
             **kwargs,
         )
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if model_inputs["inputs_embeds"] is not None:
-                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-                device = model_inputs["inputs_embeds"].device
-            else:
-                batch_size, sequence_length = model_inputs["input_ids"].shape
-                device = model_inputs["input_ids"].device
-
-            dtype = self.get_output_embeddings().weight.dtype
-            min_dtype = torch.finfo(dtype).min
-
-            model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
-                dtype=dtype,
-                device=device,
-                min_dtype=min_dtype,
-                cache_position=cache_position,
-                batch_size=batch_size,
-            )
-
-        model_inputs["token_type_ids"] = token_type_ids
-
         # position_ids in NewTaskModel are 1-indexed
         if model_inputs.get("position_ids") is not None:
             model_inputs["position_ids"] += 1
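The `prepare_inputs_for_generation` override above now forwards everything (including `token_type_ids`) to the underlying language model and only post-processes the fields it owns, shifting `position_ids` because NewTaskModel positions are 1-indexed. A toy sketch of that post-processing with a stubbed delegate (the stub is illustrative, not the real API):

```python
import torch

def stub_prepare_inputs_for_generation(input_ids, **kwargs):
    # Stand-in for self.language_model.prepare_inputs_for_generation(...)
    return {"input_ids": input_ids, "position_ids": torch.arange(input_ids.shape[1])[None, :], **kwargs}

model_inputs = stub_prepare_inputs_for_generation(torch.tensor([[5, 6, 7]]), token_type_ids=None)
if model_inputs.get("position_ids") is not None:
    model_inputs["position_ids"] += 1  # 0-indexed -> 1-indexed
```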
 examples/modular-transformers/modeling_roberta.py — 1014 changes, normal file
 File diff suppressed because it is too large
@@ -1,26 +1,24 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from examples/modular-transformers/modular_super.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly. One of our CI enforces this
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_super.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, StaticCache
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
+from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -33,59 +31,6 @@ from .configuration_super import SuperConfig
 logger = logging.get_logger(__name__)
 
 
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    min_dtype: float,
-    cache_position: torch.Tensor,
-    batch_size: int,
-):
-    """
-    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-    Args:
-        attention_mask (`torch.Tensor`):
-            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-        sequence_length (`int`):
-            The sequence length being processed.
-        target_length (`int`):
-            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-        dtype (`torch.dtype`):
-            The dtype to use for the 4D attention mask.
-        device (`torch.device`):
-            The device to plcae the 4D attention mask on.
-        min_dtype (`float`):
-            The minimum value representable with the dtype `dtype`.
-        cache_position (`torch.Tensor`):
-            Indices depicting the position of the input sequence tokens in the sequence.
-        batch_size (`torch.Tensor`):
-            Batch size.
-    """
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
-
-    return causal_mask
-
-
 class SuperRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -123,7 +68,7 @@ class SuperRotaryEmbedding(nn.Module):
         if config is None:
             logger.warning_once(
                 "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.45"
+                "`config` argument. All other arguments will be removed in v4.46"
             )
             self.rope_kwargs = {
                 "rope_type": rope_type,
@@ -193,40 +138,6 @@ class SuperRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class SuperMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -261,6 +172,40 @@ class SuperMLP(nn.Module):
         return down_proj
 
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
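`rotate_half` and `apply_rotary_pos_emb` are only relocated here (they now sit after `SuperMLP`); their behaviour is unchanged. A self-contained usage sketch with toy rotary frequencies, in case the broadcasting is not obvious — the frequency construction below is illustrative, not the model's real `inv_freq`:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

q = k = torch.randn(1, 4, 6, 8)                              # (batch, heads, seq_len, head_dim)
angles = torch.outer(torch.arange(6.0), torch.ones(4))       # toy per-position angles
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[None]  # (1, seq_len, head_dim)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[None]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)          # unsqueeze_dim=1 broadcasts over heads
```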
@@ -302,7 +247,7 @@ class SuperAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
+        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = SuperRotaryEmbedding(config=self.config)
 
     def forward(
@@ -314,7 +259,7 @@ class SuperAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -349,7 +294,7 @@ class SuperAttention(nn.Module):
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                 "removed and `position_embeddings` will be mandatory."
             )
             cos, sin = self.rotary_emb(value_states, position_ids)
@@ -422,7 +367,8 @@ class SuperFlashAttention2(SuperAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -449,7 +395,7 @@ class SuperFlashAttention2(SuperAttention):
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                 "removed and `position_embeddings` will be mandatory."
             )
             cos, sin = self.rotary_emb(value_states, position_ids)
@@ -507,6 +453,7 @@ class SuperFlashAttention2(SuperAttention):
             sliding_window=getattr(self, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            **kwargs,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -535,7 +482,7 @@ class SuperSdpaAttention(SuperAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -569,7 +516,7 @@ class SuperSdpaAttention(SuperAttention):
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                 "removed and `position_embeddings` will be mandatory."
             )
             cos, sin = self.rotary_emb(value_states, position_ids)
@@ -644,7 +591,7 @@ class SuperDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -790,7 +737,8 @@ SUPER_INPUTS_DOCSTRING = r"""
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
             Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance;
+            - a [`~cache_utils.Cache`] instance, see our
+              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
             shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
             cache format.
@@ -916,10 +864,9 @@ class SuperModel(SuperPreTrainedModel):
                 return None
 
         dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -928,13 +875,12 @@ class SuperModel(SuperPreTrainedModel):
             )
 
         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
             device=device,
-            min_dtype=min_dtype,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -948,6 +894,63 @@ class SuperModel(SuperPreTrainedModel):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
 
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to plcae the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
@@ -13,8 +13,5 @@ class RobertaEmbeddings(BertEmbeddings):
 
 
 class RobertaModel(BertModel):
-    def __init__(self, config):
+    def __init__(self, config, add_pooling_layer=True):
         super().__init__(self, config)
-        # Error out here. Why? Because `RobertaEmbeddings` is defined but not used.
-        # no, because it's defined, and RobertaModel should use RobertaEmbedding
-        # here if initialized that way it won't use the new embedding.
@@ -23,7 +23,6 @@ import math
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
@@ -49,7 +48,10 @@ from ...utils import (
 from .configuration_gemma import GemmaConfig
 
 
+logger = logging.get_logger(__name__)
+
 _CHECKPOINT_FOR_DOC = "google/gemma-7b"
+_CONFIG_FOR_DOC = "GemmaConfig"
 
 
 class GemmaRMSNorm(nn.Module):
@@ -72,9 +74,6 @@ class GemmaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-logger = logging.get_logger(__name__)
-
-
 class GemmaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -624,9 +623,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
         module.weight.data[module.padding_idx].zero_()
 
 
-_CONFIG_FOR_DOC = "GemmaConfig"
-
-
 GEMMA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -19,8 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from ...configuration_utils import PretrainedConfig
 
 
@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache
@@ -40,6 +39,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -48,7 +48,15 @@ from ...utils import (
 from .configuration_gemma2 import Gemma2Config
 
 
+if is_flash_attn_2_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
 _CHECKPOINT_FOR_DOC = "google/gemma2-7b"
+_CONFIG_FOR_DOC = "Gemma2Config"
 
 
 class Gemma2RMSNorm(nn.Module):
@@ -86,9 +94,6 @@ class Gemma2MLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
-logger = logging.get_logger(__name__)
-
-
 class Gemma2RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -198,12 +203,12 @@ class Gemma2Attention(nn.Module):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
-        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
         self.rotary_emb = Gemma2RotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
         )
+        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
 
     def forward(
         self,
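The relocated `self.sliding_window` assignment above is what makes Gemma2 alternate between sliding-window and global attention based on layer parity (window on even layers, `None` on odd). A toy illustration of that schedule; the window size is arbitrary:

```python
sliding_window = 4096
for layer_idx in range(6):
    window = sliding_window if not bool(layer_idx % 2) else None
    print(f"layer {layer_idx}: {'sliding-window' if window else 'global'} attention")
```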
@@ -495,12 +500,12 @@ class Gemma2DecoderLayer(nn.Module):
         self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.config = config
         self.is_sliding = not bool(layer_idx % 2)
         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.sliding_window = config.sliding_window
-        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -638,9 +643,6 @@ class Gemma2PreTrainedModel(PreTrainedModel):
         return config
 
 
-_CONFIG_FOR_DOC = "Gemma2Config"
-
-
 GEMMA2_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -865,6 +867,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
             attentions=all_self_attns,
         )
 
+    @torch.no_grad()
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -50,7 +49,10 @@ from ...utils import (
 from .configuration_glm import GlmConfig
 
 
+logger = logging.get_logger(__name__)
+
 _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
+_CONFIG_FOR_DOC = "GlmConfig"
 
 
 class GlmRMSNorm(nn.Module):
@@ -121,7 +123,16 @@ class GlmMLP(nn.Module):
         return self.down_proj(up_states)
 
 
-logger = logging.get_logger(__name__)
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
 def rotate_half(x):
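`repeat_kv` above is moved earlier in the generated GLM file with its behaviour unchanged: it expands grouped key/value heads to the full number of attention heads. A self-contained shape check of the same function:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 5, 8)                    # 2 key/value heads (grouped-query attention)
assert repeat_kv(kv, 4).shape == (1, 8, 5, 8)   # repeated to 8 attention heads
```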
@ -172,18 +183,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
|
||||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
|
||||||
"""
|
|
||||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
||||||
if n_rep == 1:
|
|
||||||
return hidden_states
|
|
||||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
|
||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
||||||
|
|
||||||
|
|
||||||
class GlmAttention(nn.Module):
|
class GlmAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
@ -608,9 +607,6 @@ class GlmPreTrainedModel(PreTrainedModel):
|
|||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "GlmConfig"
|
|
||||||
|
|
||||||
|
|
||||||
GLM_INPUTS_DOCSTRING = r"""
|
GLM_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
src/transformers/models/instructblipvideo/modeling_instructblipvideo.py

@@ -24,7 +24,6 @@ from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union

 import torch
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -347,104 +346,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()


-INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
-    Args:
-        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
-            [`InstructBlipVideoProcessor.__call__`] for details.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
-            Whether to interpolate the pre-trained position encodings.
-"""
-
-INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
-    Args:
-        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
-            [`InstructBlipVideoProcessor.__call__`] for details.
-
-        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
-            to serve as text prompt, which the Q-Former model will encode.
-
-            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
-            details.
-
-            [What are input IDs?](../glossary#input-ids)
-
-        qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
-            provided to serve as text prompt, which the language model can continue.
-
-            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
-            details.
-
-            [What are input IDs?](../glossary#input-ids)
-
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
-            Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
-            encoder-decoder language model (like T5) is used.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
-
-        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
-            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
-            be used by default.
-
-            Only relevant in case an encoder-decoder language model (like T5) is used.
-
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
-            Whether to interpolate the pre-trained position encodings.
-"""
-
-
 class InstructBlipVideoEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -531,6 +432,24 @@ class InstructBlipVideoEncoder(nn.Module):
        )


+INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
+            [`InstructBlipVideoProcessor.__call__`] for details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+"""
+
+
 class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
     main_input_name = "pixel_values"
     config_class = InstructBlipVideoVisionConfig
@@ -1268,6 +1187,87 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
        )


+INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
+            [`InstructBlipVideoProcessor.__call__`] for details.
+
+        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
+            to serve as text prompt, which the Q-Former model will encode.
+
+            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
+            details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+        qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
+            provided to serve as text prompt, which the language model can continue.
+
+            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
+            details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
+            encoder-decoder language model (like T5) is used.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            Only relevant in case an encoder-decoder language model (like T5) is used.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+"""
+
+
 @add_start_docstrings(
     """
     InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
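These module-level docstring constants carry no logic of their own; they are consumed by the doc decorators (`add_start_docstrings`, `add_start_docstrings_to_model_forward`), which is why the converter can relocate them next to the classes that use them. A small illustration (the decorator is the real `transformers` helper; the tiny class and the constant name are made up for the example):

from transformers.utils import add_start_docstrings_to_model_forward

EXAMPLE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values.
"""

class TinyModel:
    @add_start_docstrings_to_model_forward(EXAMPLE_INPUTS_DOCSTRING)
    def forward(self, pixel_values=None):
        return pixel_values

# The Args block above is now part of the rendered documentation of TinyModel.forward.
print(TinyModel.forward.__doc__)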
src/transformers/models/llava_next_video/modeling_llava_next_video.py

@@ -25,7 +25,6 @@ from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
-import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
@@ -33,12 +32,7 @@ from ...generation import GenerationMixin
 from ...image_processing_utils import select_best_resolution
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_llava_next_video import LlavaNextVideoConfig
@@ -48,113 +42,6 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LlavaNextVideoConfig"


-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    """
-    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (`tuple`):
-            The size of the input image in the format (width, height).
-        grid_pinpoints (`List`):
-            A list containing possible resolutions. Each item in the list should be a tuple or list
-            of the form `(height, width)`.
-        patch_size (`int`):
-            The size of each image patch.
-
-    Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
-    """
-    if not isinstance(grid_pinpoints, list):
-        raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
-    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
-    if not isinstance(image_size, (list, tuple)):
-        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(
-                f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
-            )
-        image_size = image_size.tolist()
-
-    height, width = select_best_resolution(image_size, grid_pinpoints)
-    return height // patch_size, width // patch_size
-
-
-def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
-    """
-    Calculate the number of patches after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
-            The size of the input image in the format (height, width). ?
-        grid_pinpoints (`List`):
-            A list containing possible resolutions. Each item in the list should be a tuple or list
-            of the form `(height, width)`.
-        patch_size (`int`):
-            The size of each image patch.
-
-    Returns:
-        int: the number of patches
-    """
-    if not isinstance(grid_pinpoints, list):
-        raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
-    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
-    if not isinstance(image_size, (list, tuple)):
-        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
-        image_size = image_size.tolist()
-
-    best_resolution = select_best_resolution(image_size, grid_pinpoints)
-    height, width = best_resolution
-    num_patches = 0
-    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            num_patches += 1
-    # add the base patch
-    num_patches += 1
-    return num_patches
-
-
-def unpad_image(tensor, original_size):
-    """
-    Unpads a PyTorch tensor of a padded and resized image.
-
-    Args:
-        tensor (`torch.Tensor`):
-            The image tensor, assumed to be of shape (num_channels, height, width).
-        original_size (`tuple`):
-            The original size of the image (height, width).
-
-    Returns:
-        `torch.Tensor`: The unpadded image tensor.
-    """
-    if not isinstance(original_size, (list, tuple)):
-        if not isinstance(original_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(
-                f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
-            )
-        original_size = original_size.tolist()
-    original_height, original_width = original_size
-    current_height, current_width = tensor.shape[1:]
-
-    original_aspect_ratio = original_width / original_height
-    current_aspect_ratio = current_width / current_height
-
-    if original_aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / original_width
-        new_height = int(round(original_height * scale_factor, 7))
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding : current_height - padding, :]
-    else:
-        scale_factor = current_height / original_height
-        new_width = int(round(original_width * scale_factor, 7))
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding : current_width - padding]
-
-    return unpadded_tensor
-
-
 @dataclass
 class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
     """
@@ -304,6 +191,113 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
                module.weight.data[module.padding_idx].zero_()


+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """
+    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`tuple`):
+            The size of the input image in the format (width, height).
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        tuple: The shape of the image patch grid in the format (width, height).
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
+    if not isinstance(image_size, (list, tuple)):
+        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(
+                f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+            )
+        image_size = image_size.tolist()
+
+    height, width = select_best_resolution(image_size, grid_pinpoints)
+    return height // patch_size, width // patch_size
+
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+    """
+    Calculate the number of patches after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+            The size of the input image in the format (height, width). ?
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        int: the number of patches
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
+    if not isinstance(image_size, (list, tuple)):
+        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+        image_size = image_size.tolist()
+
+    best_resolution = select_best_resolution(image_size, grid_pinpoints)
+    height, width = best_resolution
+    num_patches = 0
+    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            num_patches += 1
+    # add the base patch
+    num_patches += 1
+    return num_patches
+
+
+def unpad_image(tensor, original_size):
+    """
+    Unpads a PyTorch tensor of a padded and resized image.
+
+    Args:
+        tensor (`torch.Tensor`):
+            The image tensor, assumed to be of shape (num_channels, height, width).
+        original_size (`tuple`):
+            The original size of the image (height, width).
+
+    Returns:
+        `torch.Tensor`: The unpadded image tensor.
+    """
+    if not isinstance(original_size, (list, tuple)):
+        if not isinstance(original_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(
+                f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+            )
+        original_size = original_size.tolist()
+    original_height, original_width = original_size
+    current_height, current_width = tensor.shape[1:]
+
+    original_aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height
+
+    if original_aspect_ratio > current_aspect_ratio:
+        scale_factor = current_width / original_width
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
+    else:
+        scale_factor = current_height / original_height
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+    return unpadded_tensor
+
+
 LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
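For orientation, a small usage sketch of the relocated anyres helpers (not part of the diff; it assumes a `transformers` install that already ships `llava_next_video`, and the pinpoint values are just example data):

import torch
from transformers.models.llava_next_video.modeling_llava_next_video import (
    get_anyres_image_grid_shape,
    image_size_to_num_patches,
    unpad_image,
)

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

# Shape of the high-resolution patch grid chosen for a 672x336 image.
grid_shape = get_anyres_image_grid_shape([672, 336], grid_pinpoints, patch_size=336)

# Number of patches the vision tower receives: the high-resolution patches plus the base patch.
num_patches = image_size_to_num_patches(image_size=[672, 336], grid_pinpoints=grid_pinpoints, patch_size=336)

# Strip the symmetric padding that was added when resizing while preserving the original aspect ratio.
features = torch.zeros(3, 336, 336)
unpadded = unpad_image(features, original_size=[300, 336])
print(grid_shape, num_patches, unpadded.shape)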
src/transformers/models/llava_next_video/modular_llava_next_video.py

@@ -30,7 +30,6 @@ from transformers.models.llava_next.modeling_llava_next import (
 from ...configuration_utils import PretrainedConfig
 from ...utils import (
     logging,
-    replace_return_docstrings,
 )
 from ..auto import CONFIG_MAPPING
@@ -309,7 +308,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
        video_features = torch.split(video_features, frames, dim=0)
        return video_features

-    @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig")
     def forward(
        self,
        input_ids: torch.LongTensor = None,
utils/check_modular_conversion.py

@@ -4,6 +4,8 @@ import glob
 import logging
 from io import StringIO

+from create_dependency_mapping import find_priority_list
+
 # Console for rich printing
 from modular_model_converter import convert_modular_file
 from rich.console import Console
@@ -69,7 +71,7 @@ if __name__ == "__main__":
     if args.files == ["all"]:
         args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
     non_matching_files = 0
-    for modular_file_path in args.files:
+    for modular_file_path in find_priority_list(args.files):
         non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)

     if non_matching_files and not args.fix_and_overwrite:
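The loop now walks the modular files in the order returned by `find_priority_list` (from `create_dependency_mapping`), presumably so that a modular file is checked only after the modular models it builds on. The sketch below is only an illustration of that kind of dependency ordering; the function name `order_modular_files_sketch` and the import-scanning heuristic are made up and are not the actual implementation:

import re
from graphlib import TopologicalSorter  # stdlib, Python 3.9+


def order_modular_files_sketch(files):
    # Model name inferred from "modular_<model>.py"; purely illustrative.
    model_of = {path: path.rsplit("modular_", 1)[-1].removesuffix(".py") for path in files}
    # graph[path] = set of modular files that should be converted before `path`.
    graph = {path: set() for path in files}
    for path in files:
        with open(path, encoding="utf-8") as f:
            source = f.read()
        for other, other_model in model_of.items():
            if other != path and f"models.{other_model}." in source:
                graph[path].add(other)
    # TopologicalSorter yields predecessors (dependencies) before the files that need them.
    return list(TopologicalSorter(graph).static_order())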
File diff suppressed because it is too large