Modular transformers: modularity and inheritance for new model additions (#33248)

* update exampel

* update

* push the converted diff files for testing and ci

* correct one example

* fix class attributes and docstring

* nits

* oups

* fixed config!

* update

* nitd

* class attributes are not matched against the other, this is missing

* fixed overwriting self.xxx now onto the attributes I think

* partial fix, now order with docstring

* fix docstring order?

* more fixes

* update

* fix missing docstrings!

* examples don't all work yet

* fixup

* nit

* updated

* hick

* update

* delete

* update

* update

* update

* fix

* all default

* no local import

* fix more diff

* some fix related to "safe imports"

* push fixed

* add helper!

* style

* add a check

* all by default

* add the

* update

* FINALLY!

* nit

* fix config dependencies

* man that is it

* fix fix

* update diffs

* fix the last issue

* re-default to all

* alll the fixes

* nice

* fix properties vs setter

* fixup

* updates

* update dependencies

* make sure to install what needs to be installed

* fixup

* quick fix for now

* fix!

* fixup

* update

* update

* updates

* whitespaces

* nit

* fix

* simplify everything, and make it file agnostic (should work for image processors)

* style

* finish fixing all import issues

* fixup

* empty modeling should not be written!

* Add logic to find who depends on what

* update

* cleanup

* update

* update gemma to support positions

* some small nits

* this is the correct docstring for gemma2

* fix merging of docstrings

* update

* fixup

* update

* take doc into account

* styling

* update

* fix hidden activation

* more fixes

* final fixes!

* fixup

* fixup instruct  blip video

* update

* fix bugs

* align gemma2 with the rest as well

* updats

* revert

* update

* more reversiom

* grind

* more

* arf

* update

* order will matter

* finish del stuff

* update

* rename to modular

* fixup

* nits

* update makefile

* fixup

* update order of the checks!

* fix

* fix docstring that has a call inside

* fiix conversion check

* style

* add some initial documentation

* update

* update doc

* some fixup

* updates

* yups

* Mostly todo gimme a minut

* update

* fixup

* revert some stuff

* Review docs for the modular transformers (#33472)

Docs

* good update

* fixup

* mmm current updates lead to this code

* okay, this fixes it

* cool

* fixes

* update

* nit

* updates

* nits

* fix doc

* update

* revert bad changes

* update

* updates

* proper update

* update

* update?

* up

* update

* cool

* nits

* nits

* bon bon

* fix

* ?

* minimise changes

* update

* update

* update

* updates?

* fixed gemma2

* kind of a hack

* nits

* update

* remove `diffs` in favor of `modular`

* fix make fix copies

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur 2024-09-24 15:54:07 +02:00 committed by GitHub
parent 75b7485cc7
commit 317e069ee7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 6515 additions and 789 deletions

View File

@ -137,7 +137,7 @@ jobs:
parallelism: 1
steps:
- checkout
- run: uv pip install -e .
- run: uv pip install -e ".[quality]"
- run:
name: Show installed libraries and their versions
command: pip freeze | tee installed.txt
@ -162,13 +162,14 @@ jobs:
parallelism: 1
steps:
- checkout
- run: uv pip install -e .
- run: uv pip install -e ".[quality]"
- run:
name: Show installed libraries and their versions
command: pip freeze | tee installed.txt
- store_artifacts:
path: ~/transformers/installed.txt
- run: python utils/check_copies.py
- run: python utils/check_modular_conversion.py
- run: python utils/check_table.py
- run: python utils/check_dummies.py
- run: python utils/check_repo.py

View File

@ -36,6 +36,7 @@ autogenerate_code: deps_table_update
repo-consistency:
python utils/check_copies.py
python utils/check_modular_conversion.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
@ -80,6 +81,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_modular_conversion.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_doctest_list.py --fix_and_overwrite

View File

@ -5,6 +5,8 @@
title: Quick tour
- local: installation
title: Installation
- local: add_new_model
title: Adding a new model to `transformers`
title: Get started
- sections:
- local: pipeline_tutorial
@ -149,6 +151,8 @@
title: Interoperability with GGUF files
- local: tiktoken
title: Interoperability with TikToken files
- local: modular_transformers
title: Modularity in `transformers`
title: Developer guides
- sections:
- local: quantization/overview

View File

@ -0,0 +1,121 @@
# Modular transformers
`transformers` is an opinionated framework; our philosophy is defined in the following [conceptual guide](./philosophy).
The core of that philosophy is exemplified by the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy)
aspect of the library. This component's downside is that it limits the inheritance and importability of components from
files to others in the toolkit.
As a result, model components tend to be repeated across many files. There are as many attention layers defined
in `transformers` as there are models, and a significant number of those are identical to each other.
The unfortunate consequence is that independent implementations tend to diverge as fixes and changes get applied
to specific parts of the code.
In order to balance this issue, we introduced the concept of "copies" across the library. By adding a comment indicating
that code is a copy of another, we can enforce through CI and local commands that copies do not diverge. However,
while the complexity is low, this is often quite tedious to do.
And, finally, this contributes to adding a significant overhead to contributing models which we would like to remove.
This approach often requires model contributions to add modeling code (~1k lines), processor (~500 lines), tests, docs,
etc. Model contribution PRs rarely add less than 3-5k lines of code, with much of this code being boilerplate.
This raises the bar for contributions, and with Modular Transformers, we're aiming to lower the bar to a much more
acceptable point.
## What is it?
Modular Transformers introduces the concept of a "modular" file to a model folder. This modular file accepts code
that isn't typically accepted in modeling/processing files, as it allows importing from neighbouring models as well
as inheritance from classes to others.
This modular file defines models, processors, and the configuration class that would otherwise be defined in their
respective modules.
Finally, this feature introduces a new `linter` which will "unravel" the modular file into the "single model, single
file" directory structure. These files will get auto-generated every time the script is run; reducing the required
contributions to the modular file, and therefore only to the changes between the contributed model and others.
Model users will end up importing and using the single-file interface, so no change is expected here. Doing this, we
hope to combine the best of both worlds: enabling simple contributions while sticking to our philosophy.
This is therefore a replacement for the `# Copied from` markers, and previously contributed models can be expected to
be moved to the new Modular Transformers format in the coming months.
### Details
The "linter", which unravels the inheritance and creates all single-files from the modular file, will flatten the
inheritance while trying to be invisible to Python users. At this time, the linter flattens a **single** level of
inheritance.
For example:
- If a configuration class inherits from another and adds/deletes an argument, the generated file will either directly
reference it (in case of addition) or completely remove it (in case of deletion).
- If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically
inferred. All submodules will be automatically inferred from the superclass.
You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
file, and the corresponding files will be created for you.
### Enforcement
[TODO] We are introducing a new test, that makes sure the generated content matches what is present in the `modular_xxxx.py`
### Examples
Here is a quick example with BERT and RoBERTa. The two models are intimately related: their modeling implementation
differs solely by a change in the embedding layer.
Instead of redefining the model entirely, here is what the `modular_roberta.py` file looks like for the modeling &
configuration classes (for the sake of the example, the tokenizer is ignored at this time as very different).
```python
from torch import nn
from ..bert.configuration_bert import BertConfig
from ..bert.modeling_bert import (
BertModel,
BertEmbeddings,
BertForMaskedLM
)
# The RoBERTa config is identical to BERT's config
class RobertaConfig(BertConfig):
model_type = 'roberta'
# We redefine the embeddings here to highlight the padding ID difference, and we redefine the position embeddings
class RobertaEmbeddings(BertEmbeddings):
def __init__(self, config):
super().__init__(config())
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
# The RoBERTa model is identical to the BERT model, except for the embedding layer.
# We redefine the embeddings above, so here there is no need to do additional work
class RobertaModel(BertModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = RobertaEmbeddings(config)
# The heads now only need to redefine the model inside to the correct `RobertaModel`
class RobertaForMaskedLM(BertForMaskedLM):
def __init__(self, config):
super().__init__(config)
self.model = RobertaModel(config)
```
Note that if you do not use the dependency that you defined, you will have the following error:
```bash
ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used
when you define `BertModel`, as it is one of it's direct dependencies. Make sure
you use it in the `__init__` function.
```
Additionally, you may find a list of examples here:
## What it is not
It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.

View File

@ -1,20 +0,0 @@
# Using the `diff_converter` linter
`pip install libcst` is a must!
# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs
The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`.
Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage.
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
## How it works
We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies.
The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file.
We use ruff to automatically remove the potential duplicate imports.
## Why we use libcst instead of the native AST?
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`

View File

@ -0,0 +1,20 @@
# Using the `modular_converter` linter
`pip install libcst` is a must!
# `sh examples/modular-transformers/convert_examples.sh` to get the converted outputs
The modular converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular file like `modular_gemma.py` into a `single model single file`.
Examples of possible usage are available in the `examples/modular-transformers`, or `modular_gemma` for a full model usage.
`python utils/modular_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/modular-transformers/modular_my_new_model2.py"`
## How it works
We use the `libcst` parser to produce an AST representation of the `modular_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the modularerence dependencies.
The code from the `modular` file and the class dependency mapping are "merged" to produce the single model single file.
We use ruff to automatically remove the potential duplicate imports.
## Why we use libcst instead of the native AST?
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`

View File

@ -0,0 +1,196 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
class MyNewModelConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MyNewModel-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MyNewModelModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'my_new_model3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
new_param (`int`, *optional*, defaults to `False`):
A fun new parameter
```python
>>> from transformers import MyNewModelModel, MyNewModelConfig
>>> # Initializing a MyNewModel my_new_model-7b style configuration
>>> configuration = MyNewModelConfig()
>>> # Initializing a model from the my_new_model-7b style configuration
>>> model = MyNewModelModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "my_new_model"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=True,
head_dim=None,
new_param=0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.mlp_bias = mlp_bias
self.new_param = new_param

View File

@ -0,0 +1,97 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
class MyNewModel2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "my_new_model2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,134 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Example where we only want to overwrite the defaults of an init
from transformers import PretrainedConfig
class NewModelConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`NewModelModel`]. It is used to instantiate an NewModel
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the NewModel-7B.
e.g. [google/new_model-7b](https://huggingface.co/google/new_model-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the NewModel model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`NewModelModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import NewModelModel, NewModelConfig
>>> # Initializing a NewModel new_model-7b style configuration
>>> configuration = NewModelConfig()
>>> # Initializing a model from the new_model-7b style configuration
>>> model = NewModelModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "new_model"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256030,
hidden_size=64,
intermediate_size=90,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=1500,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def num_heads(self):
return self.num_attention_heads

View File

@ -1,7 +1,7 @@
#!/bin/bash
# Iterate over each file in the current directory
for file in examples/diff-conversion/diff_*; do
for file in examples/modular-transformers/modular_*; do
# Check if it's a regular file
if [ -f "$file" ]; then
# Call the Python script with the file name as an argument

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,953 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from .configuration_super import SuperConfig
logger = logging.get_logger(__name__)
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class SuperRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
SuperRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class SuperRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[SuperConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class SuperMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class SuperAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = SuperRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class SuperFlashAttention2(SuperAttention):
"""
Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (SuperRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class SuperSdpaAttention(SuperAttention):
"""
Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from SuperAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
SUPER_ATTENTION_CLASSES = {
"eager": SuperAttention,
"flash_attention_2": SuperFlashAttention2,
"sdpa": SuperSdpaAttention,
}
class SuperDecoderLayer(nn.Module):
def __init__(self, config: SuperConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = SuperMLP(config)
self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
SUPER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SuperConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Super Model outputting raw hidden-states without any specific head on top.",
SUPER_START_DOCSTRING,
)
class SuperPreTrainedModel(PreTrainedModel):
config_class = SuperConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SuperDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
SUPER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare Super Model outputting raw hidden-states without any specific head on top.",
SUPER_START_DOCSTRING,
)
class SuperModel(SuperPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuperDecoderLayer`]
Args:
config: SuperConfig
"""
def __init__(self, config: SuperConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[SuperDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = SuperRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
out = super().forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
out.logits *= 2**4
return out
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask

View File

@ -3,10 +3,11 @@ from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from ...cache_utils import Cache
def _pre_process_input(input_ids):
print(log(input_ids))

View File

@ -0,0 +1,27 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers.models.bert.modeling_bert import BertModel
from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
class DummyBertModel(BertModel):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
return super().forward(input_ids)

View File

@ -5,10 +5,11 @@ from transformers.models.llama.configuration_llama import LlamaConfig
# here there is no `ARG` so we are gonna take parent doc
class MyNewModelConfig(LlamaConfig):
r"""
mlp_bias (`bool`, *optional*, defaults to `False`)
new_param (`int`, *optional*, defaults to `False`):
A fun new parameter
"""
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
super().__init__(self, **super_kwargs)
self.mlp_bias = mlp_bias
self.new_param = new_param
super().__init__(self, **super_kwargs)

View File

@ -26,5 +26,10 @@ class NewModelConfig(GemmaConfig):
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
super().__init__(self)
super().__init__(self, **kwargs)
@property
def num_heads(self):
return self.num_attention_heads

View File

@ -0,0 +1,20 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel
class RobertaEmbeddings(BertEmbeddings):
def __init__(self, config):
super().__init__(config)
self.pad_token_id = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, config.pad_token_id
)
class RobertaModel(BertModel):
def __init__(self, config):
super().__init__(self, config)
# Error out here. Why? Because `RobertaEmbeddings` is defined but not used.
# no, because it's defined, and RobertaModel should use RobertaEmbedding
# here if initialized that way it won't use the new embedding.

View File

@ -2,10 +2,11 @@ from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from ...cache_utils import Cache
# example where we need some deps and some functions
class SuperModel(LlamaModel):

View File

@ -192,6 +192,8 @@ _deps = [
"urllib3<2.0.0",
"uvicorn",
"pytest-rich",
"libcst",
"rich",
]
@ -345,7 +347,7 @@ extras["testing"] = (
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
extras["ruff"] = deps_list("ruff")
extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3")
extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3", "libcst", "rich")
extras["all"] = (
extras["tf"]

View File

@ -97,4 +97,6 @@ deps = {
"urllib3": "urllib3<2.0.0",
"uvicorn": "uvicorn",
"pytest-rich": "pytest-rich",
"libcst": "libcst",
"rich": "rich",
}

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@ -21,7 +21,7 @@
# limitations under the License.
from transformers import PretrainedConfig
from ...configuration_utils import PretrainedConfig
class GemmaConfig(PretrainedConfig):

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@ -39,7 +39,6 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -51,63 +50,6 @@ from ...utils import (
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -128,7 +70,7 @@ class GemmaRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
logger = logging.get_logger(__name__)
class GemmaRotaryEmbedding(nn.Module):
@ -159,30 +101,6 @@ class GemmaRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
logger.warning_once(
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
"`config.hidden_activation` if you want to override this behaviour.\n"
"See https://github.com/huggingface/transformers/pull/29402 for more details."
)
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
"""GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@ -212,6 +130,30 @@ class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
return cos, sin
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
logger.warning_once(
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
"`config.hidden_activation` if you want to override this behaviour.\n"
"See https://github.com/huggingface/transformers/pull/29402 for more details."
)
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@ -358,6 +300,94 @@ class GemmaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class GemmaSdpaAttention(GemmaAttention):
"""
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class GemmaFlashAttention2(GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@ -458,7 +488,6 @@ class GemmaFlashAttention2(GemmaAttention):
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
@ -468,92 +497,57 @@ class GemmaFlashAttention2(GemmaAttention):
return attn_output, attn_weights, past_key_value
class GemmaSdpaAttention(GemmaAttention):
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
# Adapted from GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return causal_mask
GEMMA_ATTENTION_CLASSES = {
@ -567,9 +561,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -830,9 +822,9 @@ class GemmaModel(GemmaPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
return_legacy_cache = True # noqa: F841
if past_key_values is None:
past_key_values = DynamicCache()
else:
@ -975,6 +967,7 @@ class GemmaModel(GemmaPreTrainedModel):
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@ -1149,6 +1142,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
@ -1230,7 +1224,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,

View File

@ -21,8 +21,15 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig
from transformers.models.llama.modeling_llama import (
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import is_torchdynamo_compiling, logging
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaForSequenceClassification,
@ -32,14 +39,6 @@ from transformers.models.llama.modeling_llama import (
repeat_kv,
)
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
logger = logging.get_logger(__name__)
@ -216,6 +215,35 @@ class GemmaRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
"""GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def forward(self, x, position_ids):
# difference to the original RoPE: a scaling factor is aplied to the position ids
position_ids = position_ids.float() / self.scaling_factor
cos, sin = super().forward(x, position_ids)
return cos, sin
class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
"""GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def forward(self, x, position_ids):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
cos, sin = super().forward(x, position_ids)
return cos, sin
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
@ -340,8 +368,95 @@ class GemmaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention?
class GemmaFlashAttention2(LlamaFlashAttention2):
class GemmaSdpaAttention(GemmaAttention):
"""
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@ -427,12 +542,12 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
@ -442,7 +557,95 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
return attn_output, attn_weights, past_key_value
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
"sdpa": GemmaSdpaAttention,
}
class GemmaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__(config)
self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class GemmaModel(LlamaModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet!
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
@ -455,7 +658,7 @@ class GemmaModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -513,22 +716,72 @@ class GemmaModel(LlamaModel):
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
return super().forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
use_cache,
cache_position,
input_ids=None,
inputs_embeds=hidden_states,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# Example where we ony modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
class GemmaForCausalLM(LlamaForCausalLM):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
@ -542,18 +795,9 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
@ -589,10 +833,18 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -618,8 +870,14 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.post_init()
class GemmaForTokenClassification(LlamaForTokenClassification):
pass
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.post_init()

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@ -19,7 +19,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PretrainedConfig
from ...configuration_utils import PretrainedConfig
class Gemma2Config(PretrainedConfig):
@ -53,7 +55,8 @@ class Gemma2Config(PretrainedConfig):
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder.
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
@ -77,16 +80,17 @@ class Gemma2Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
```python
>>> from transformers import Gemma2Model, Gemma2Config
>>> # Initializing a Gemma2 gemma2-9b style configuration
>>> # Initializing a Gemma2 gemma2-7b style configuration
>>> configuration = Gemma2Config()
>>> # Initializing a model from the gemma2-9b style configuration
>>> # Initializing a model from the gemma2-7b style configuration
>>> model = Gemma2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
@ -94,6 +98,7 @@ class Gemma2Config(PretrainedConfig):
model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
cache_implementation = "hybrid"
def __init__(
self,
@ -116,12 +121,19 @@ class Gemma2Config(PretrainedConfig):
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
query_pre_attn_scalar=224,
sliding_window=4096,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@ -130,23 +142,14 @@ class Gemma2Config(PretrainedConfig):
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.attn_logit_softcapping = attn_logit_softcapping
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.final_logit_softcapping = final_logit_softcapping
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.cache_implementation = "hybrid"
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@ -22,13 +22,14 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -39,7 +40,6 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
@ -49,66 +49,6 @@ from ...utils import (
from .configuration_gemma2 import Gemma2Config
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -129,6 +69,24 @@ class Gemma2RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
logger = logging.get_logger(__name__)
class Gemma2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@ -191,21 +149,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@ -253,12 +196,12 @@ class Gemma2Attention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.rotary_emb = Gemma2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
@ -502,9 +445,11 @@ class Gemma2SdpaAttention(Gemma2Attention):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
@ -515,6 +460,7 @@ class Gemma2SdpaAttention(Gemma2Attention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@ -533,6 +479,59 @@ class Gemma2SdpaAttention(Gemma2Attention):
return attn_output, None, past_key_value
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
"flash_attention_2": Gemma2FlashAttention2,
@ -543,19 +542,16 @@ GEMMA2_ATTENTION_CLASSES = {
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
@ -567,6 +563,25 @@ class Gemma2DecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":
@ -580,6 +595,7 @@ class Gemma2DecoderLayer(nn.Module):
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -711,13 +727,20 @@ GEMMA2_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`HybridCache`, *optional*):
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
cache classes.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@ -812,8 +835,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Instantiate an empty cache if needed.
if use_cache and past_key_values is None:
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
@ -828,6 +850,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@ -844,6 +867,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@ -880,7 +904,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -1009,6 +1032,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
if self.training and self.config._attn_implementation != "eager":
logger.warning_once(
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
@ -1187,10 +1211,10 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,

View File

@ -13,30 +13,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.models.gemma.modeling_gemma import (
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
)
from ..gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
GemmaRMSNorm,
_prepare_4d_causal_attention_mask_with_cache_position,
apply_rotary_pos_emb,
repeat_kv,
)
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -45,33 +56,230 @@ if is_flash_attn_2_available():
logger = logging.get_logger(__name__)
class Gemma2Config(GemmaConfig):
cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better
class Gemma2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma2-7B.
e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma2Model`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
```python
>>> from transformers import Gemma2Model, Gemma2Config
>>> # Initializing a Gemma2 gemma2-7b style configuration
>>> configuration = Gemma2Config()
>>> # Initializing a model from the gemma2-7b style configuration
>>> model = Gemma2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
cache_implementation = "hybrid"
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
query_pre_attn_scalar=224,
sliding_window=4096,
final_logit_softcapping=30.0,
**super_kwargs,
attn_logit_softcapping=50.0,
**kwargs,
):
super().__init__(self, **super_kwargs)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.cache_implementation = "hybrid"
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
class Gemma2RMSNorm(GemmaRMSNorm):
pass
class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Gemma2Attention(GemmaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.scaling = config.query_pre_attn_scalar**-0.5
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Gemma2FlashAttention2(Gemma2Attention):
@ -119,9 +327,19 @@ class Gemma2FlashAttention2(Gemma2Attention):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]
value_states = value_states[:, :, :seq_len]
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
@ -156,7 +374,6 @@ class Gemma2FlashAttention2(Gemma2Attention):
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING
attn_output = _flash_attention_forward(
query_states,
key_states,
@ -166,7 +383,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
dropout=dropout_rate,
softmax_scale=self.scaling,
is_causal=self.is_causal,
sliding_window=self.sliding_window,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@ -227,7 +446,12 @@ class Gemma2SdpaAttention(Gemma2Attention):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -269,8 +493,9 @@ class Gemma2SdpaAttention(Gemma2Attention):
class Gemma2DecoderLayer(GemmaDecoderLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__(config, layer_idx)
self.is_sliding = bool(layer_idx % 2)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.mlp = Gemma2MLP(config)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@ -286,11 +511,18 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
attention_mask = attention_mask * torch.tril(
torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
)
if cache_position[0] > 0:
# Flash-attn is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states
@ -326,13 +558,38 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
return outputs
class Gemma2Model(GemmaModel):
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
_supports_quantized_cache = False
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
"""
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
"""
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
# if using the default path -> swap sdpa by eager
if not hard_check_only and config._attn_implementation == "sdpa":
config._attn_implementation = "eager"
return config
class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
def __init__(self, config: Gemma2Config):
super().__init__(config)
self.layers = nn.ModuleList(
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -361,8 +618,21 @@ class Gemma2Model(GemmaModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)
if cache_position is None:
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@ -437,50 +707,50 @@ class Gemma2Model(GemmaModel):
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
past_key_values: HybridCache,
output_attentions: bool,
):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if past_key_values is not None:
if isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_length()
else:
target_length = attention_mask.shape[-1]
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
class Gemma2ForCausalLM(GemmaForCausalLM):
def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -488,18 +758,9 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
@ -514,12 +775,17 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
if self.training and self.config._attn_implementation != "eager":
logger.warning_once(
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@ -535,15 +801,23 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
# TODO: remove the float() operation in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -567,10 +841,94 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
pass
def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()
class Gemma2ForTokenClassification(GemmaForTokenClassification):
pass
def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
@ -24,9 +24,7 @@ from typing import Union
from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import (
logging,
)
from ...utils import logging
from ..auto import CONFIG_MAPPING
@ -36,8 +34,8 @@ logger = logging.get_logger(__name__)
class InstructBlipVideoVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
instantiate a Instructblipvideo vision encoder according to the specified arguments, defining the model architecture.
Instantiating a configuration defaults will yield a similar configuration to that of the Instructblipvideo
instantiate a InstructBlipVideo vision encoder according to the specified arguments, defining the model architecture.
Instantiating a configuration defaults will yield a similar configuration to that of the InstructBlipVideo
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
@ -58,7 +56,7 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer
`"relu"`, `"selu"` and `"gelu_new"` `"gelu"` are supported. to 1e-5): The epsilon used by the layer
normalization layers.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
@ -137,9 +135,9 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
class InstructBlipVideoQFormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
instantiate a Instructblipvideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
instantiate a InstructBlipVideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the Instructblipvideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
Read the documentation from [`PretrainedConfig`] for more information.
@ -189,7 +187,7 @@ class InstructBlipVideoQFormerConfig(PretrainedConfig):
```python
>>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
>>> # Initializing a Instructblipvideo Salesforce/instruct-blip-flan-t5 style configuration
>>> # Initializing a InstructBlipVideo Salesforce/instruct-blip-flan-t5 style configuration
>>> configuration = InstructBlipVideoQFormerConfig()
>>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
@ -360,7 +358,7 @@ class InstructBlipVideoConfig(PretrainedConfig):
**kwargs,
):
r"""
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a Instructblipvideo vision model, Q-Former and
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
language model configurations.
Returns:

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
@ -354,6 +353,21 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@ -371,6 +385,71 @@ INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
Whether to interpolate the pre-trained position encodings.
"""
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
[`InstructBlipVideoProcessor.__call__`] for details.
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
to serve as text prompt, which the Q-Former model will encode.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
provided to serve as text prompt, which the language model can continue.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
encoder-decoder language model (like T5) is used.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
Only relevant in case an encoder-decoder language model (like T5) is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
class InstructBlipVideoEncoder(nn.Module):
@ -459,87 +538,6 @@ class InstructBlipVideoEncoder(nn.Module):
)
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
[`InstructBlipVideoProcessor.__call__`] for details.
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
to serve as text prompt, which the Q-Former model will encode.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
provided to serve as text prompt, which the language model can continue.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
encoder-decoder language model (like T5) is used.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
Only relevant in case an encoder-decoder language model (like T5) is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values"
@ -1089,7 +1087,7 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
"""
Querying Transformer (Q-Former), used in Instructblipvideo. Slightly modified from BLIP-2 as it also takes the
Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
instruction as input.
"""
@ -1285,7 +1283,7 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
@add_start_docstrings(
"""
Instructblipvideo Model for generating text given an image and an optional text prompt. The model consists of a vision
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
encoder, Querying Transformer (Q-Former) and a language model.
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
@ -1358,7 +1356,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + Instructblipvideo + `accelerate`.
# warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
@ -1505,7 +1503,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@ -1584,7 +1581,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
"""
r"""
Overrides `generate` function to be able to use the model as a conditional generator.
Args:

View File

@ -21,32 +21,18 @@ import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.models.instructblip.configuration_instructblip import (
InstructBlipConfig,
InstructBlipQFormerConfig,
InstructBlipVisionConfig,
)
from transformers.models.instructblip.modeling_instructblip import (
InstructBlipAttention,
InstructBlipEncoder,
InstructBlipEncoderLayer,
InstructBlipForConditionalGeneration,
InstructBlipForConditionalGenerationModelOutput,
InstructBlipMLP,
InstructBlipPreTrainedModel,
InstructBlipQFormerAttention,
InstructBlipQFormerEmbeddings,
InstructBlipQFormerEncoder,
InstructBlipQFormerIntermediate,
InstructBlipQFormerLayer,
InstructBlipQFormerModel,
InstructBlipQFormerOutput,
InstructBlipQFormerSelfOutput,
InstructBlipVisionEmbeddings,
InstructBlipVisionModel,
)
from ...generation import GenerationMixin
from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
@ -60,8 +46,124 @@ class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
pass
class InstructBlipVideoConfig(InstructBlipConfig):
pass
class InstructBlipVideoConfig(PretrainedConfig):
r"""
[`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
[`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified
arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
the defaults will yield a similar configuration to that of the Instructblipvideo
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
qformer_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
num_query_tokens (`int`, *optional*, defaults to 32):
The number of query tokens passed through the Transformer.
video_token_index (`int`, *optional*):
Token index of special video token.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import (
... InstructBlipVideoVisionConfig,
... InstructBlipVideoQFormerConfig,
... OPTConfig,
... InstructBlipVideoConfig,
... InstructBlipVideoForConditionalGeneration,
... )
>>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
>>> configuration = InstructBlipVideoConfig()
>>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
>>> model = InstructBlipVideoForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
>>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
>>> vision_config = InstructBlipVideoVisionConfig()
>>> qformer_config = InstructBlipVideoQFormerConfig()
>>> text_config = OPTConfig()
>>> config = InstructBlipVideoConfig.from_text_vision_configs(vision_config, qformer_config, text_config)
```"""
model_type = "instructblipvideo"
def __init__(
self,
vision_config=None,
qformer_config=None,
text_config=None,
num_query_tokens=32,
video_token_index=None,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
if qformer_config is None:
qformer_config = {}
logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder
self.num_query_tokens = num_query_tokens
self.video_token_index = video_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
self.initializer_factor = 1.0
self.initializer_range = 0.02
@classmethod
def from_vision_qformer_text_configs(
cls,
vision_config: InstructBlipVideoVisionConfig,
qformer_config: InstructBlipVideoQFormerConfig,
text_config: PretrainedConfig,
**kwargs,
):
r"""
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
language model configurations.
Returns:
[`InstructBlipVideoConfig`]: An instance of a configuration object
"""
return cls(
vision_config=vision_config.to_dict(),
qformer_config=qformer_config.to_dict(),
text_config=text_config.to_dict(),
**kwargs,
)
@dataclass
@ -69,67 +171,7 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
pass
class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings):
pass
class InstructBlipVideoAttention(InstructBlipAttention):
pass
class InstructBlipVideoMLP(InstructBlipMLP):
pass
class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer):
pass
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
pass
class InstructBlipVideoEncoder(InstructBlipEncoder):
pass
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
pass
class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput):
pass
class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention):
pass
class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate):
pass
class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput):
pass
class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer):
pass
class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder):
pass
class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings):
pass
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
pass
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin):
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def forward(
self,
pixel_values: torch.FloatTensor,
@ -146,15 +188,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size]`
Returns:
Examples:
```python
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
>>> import torch
@ -339,7 +372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
"""
r"""
Overrides `generate` function to be able to use the model as a conditional generator.
Args:

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
@ -20,17 +20,10 @@
# limitations under the License.
from transformers import PretrainedConfig
from ...utils import (
logging,
)
from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class LlavaNextVideoConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an

View File

@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@ -130,6 +129,12 @@ def unpad_image(tensor, original_size):
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
if not isinstance(original_size, (list, tuple)):
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
original_size = original_size.tolist()
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
@ -180,6 +185,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
@ -191,6 +197,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@ -455,7 +462,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.vocab_size = model_embeds.num_embeddings
return model_embeds
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration._merge_input_ids_with_image_features
def _merge_input_ids_with_image_features(
self,
image_features,
@ -695,7 +701,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
def pack_image_features(self, image_features, image_sizes, image_newline=None):
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@ -704,6 +710,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
@ -718,8 +726,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:
if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,

View File

@ -21,15 +21,13 @@ import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextMultiModalProjector,
image_size_to_num_patches,
)
from ...generation import GenerationMixin
from ...configuration_utils import PretrainedConfig
from ...utils import (
logging,
replace_return_docstrings,
@ -56,18 +54,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
video_token_index (`int`, *optional*, defaults to 32000):
The video token index to encode the image prompt.
image_token_index (`int`, *optional*, defaults to 32001):
The image token index to encode the image prompt.
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
Pooling mode to use for videos. Can be "average", "max" or "conv".
spatial_pool_stride (`int`, *optional*, defaults to 2):
Stride used in the pooling layer for videos.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
video_seq_length (`int`, *optional*, defaults to 288):
Sequence length of one video embedding.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
@ -81,6 +69,16 @@ class LlavaNextVideoConfig(PretrainedConfig):
of the form `(height, width)`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
video_token_index (`int`, *optional*, defaults to 32000):
The video token index to encode the image prompt.
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
Pooling mode to use for videos. Can be "average", "max" or "conv".
spatial_pool_stride (`int`, *optional*, defaults to 2):
Stride used in the pooling layer for videos.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
video_seq_length (`int`, *optional*, defaults to 288):
Sequence length of one video embedding.
Example:
@ -178,7 +176,13 @@ class LlavaNextVideoConfig(PretrainedConfig):
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
pass
"""
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
video_hidden_states: Optional[torch.FloatTensor] = None
class LlavaNextVideoPooler(nn.Module):
@ -215,11 +219,7 @@ class LlavaNextVideoPooler(nn.Module):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
pass
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin):
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
self.vision_resampler = LlavaNextVideoPooler(config)
@ -287,6 +287,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
Args:
@ -298,6 +300,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
@ -329,7 +335,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto)
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto")
>>> processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
>>> prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
@ -499,6 +505,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)
logits = outputs[0]
@ -529,6 +536,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
def prepare_inputs_for_generation(
@ -541,6 +550,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
image_sizes=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
if input_ids is not None:
@ -560,6 +570,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)

View File

@ -182,6 +182,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.

View File

@ -0,0 +1,76 @@
import argparse
import difflib
import glob
import logging
from io import StringIO
# Console for rich printing
from modular_model_converter import convert_modular_file
from rich.console import Console
from rich.syntax import Syntax
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
console = Console()
def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
file_path = modular_file_path.replace("modular_", f"{file_type}_")
# Read the actual modeling file
with open(file_path, "r") as modeling_file:
content = modeling_file.read()
output_buffer = StringIO(generated_modeling_content[file_type][0])
output_buffer.seek(0)
output_content = output_buffer.read()
diff = difflib.unified_diff(
output_content.splitlines(),
content.splitlines(),
fromfile=f"{file_path}_generated",
tofile=f"{file_path}",
lineterm="",
)
diff_list = list(diff)
# Check for differences
if diff_list:
if fix_and_overwrite:
with open(file_path, "w") as modeling_file:
modeling_file.write(generated_modeling_content[file_type][0])
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
else:
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
diff_text = "\n".join(diff_list)
syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
console.print(syntax)
return 1
else:
console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
return 0
def compare_files(modular_file_path, fix_and_overwrite=False):
# Generate the expected modeling content
generated_modeling_content = convert_modular_file(modular_file_path)
diff = 0
for file_type in generated_modeling_content.keys():
diff += process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
return diff
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
parser.add_argument(
"--files", default=["all"], type=list, nargs="+", help="List of modular_xxx.py files to compare."
)
parser.add_argument(
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
)
args = parser.parse_args()
if args.files == ["all"]:
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
non_matching_files = 0
for modular_file_path in args.files:
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
if non_matching_files and not args.fix_and_overwrite:
raise ValueError("Some diff and their modeling code did not match.")

View File

@ -0,0 +1,69 @@
import ast
from collections import defaultdict, deque
# Function to perform topological sorting
def topological_sort(dependencies):
# Create a graph and in-degree count for each node
graph = defaultdict(list)
in_degree = defaultdict(int)
# Build the graph
for node, deps in dependencies.items():
for dep in deps:
graph[dep].append(node) # node depends on dep
in_degree[node] += 1 # increase in-degree of node
# Add all nodes with zero in-degree to the queue
zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
sorted_list = []
# Perform topological sorting
while zero_in_degree_queue:
current = zero_in_degree_queue.popleft()
sorted_list.append(current)
# For each node that current points to, reduce its in-degree
for neighbor in graph[current]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
zero_in_degree_queue.append(neighbor)
# Handle nodes that have no dependencies and were not initially part of the loop
for node in dependencies:
if node not in sorted_list:
sorted_list.append(node)
return sorted_list
# Function to extract class and import info from a file
def extract_classes_and_imports(file_path):
with open(file_path, "r") as file:
tree = ast.parse(file.read(), filename=file_path)
imports = set()
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = node.module if isinstance(node, ast.ImportFrom) else None
if module and "transformers" in module:
imports.add(module)
return imports
# Function to map dependencies between classes
def map_dependencies(py_files):
dependencies = defaultdict(set)
# First pass: Extract all classes and map to files
for file_path in py_files:
dependencies[file_path].add(None)
class_to_file = extract_classes_and_imports(file_path)
for module in class_to_file:
dependencies[file_path].add(module)
return dependencies
def find_priority_list(py_files):
dependencies = map_dependencies(py_files)
ordered_classes = topological_sort(dependencies)
return ordered_classes[::-1]

View File

@ -20,21 +20,23 @@ from typing import Dict
import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
@ -82,12 +84,16 @@ class ClassFinder(CSTVisitor):
self.function_def = {} # stores global scope function definition
self.assignments = {} # LLAMA_DOCSTRING
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
# fmt: on
def _update_class_dependency(self, name, value):
"""Update the dependency mapping for `name` with `value` by appending the previous
dependencies to the new `value`.
"""
dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value})
self.first_lvl_dependency_mapping[name] = dep
dep = set(self.class_dependency_mapping.get(value, set()))
dep |= set(self.class_dependency_mapping.get(name, {})) | set({value})
self.class_dependency_mapping[name] = dep
@ -146,7 +152,16 @@ class ClassFinder(CSTVisitor):
def leave_Decorator(self, node):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
if k.value.value in self.assignments:
if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value:
if k.value.func.value.value not in self.assignments:
raise ValueError(
f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}"
)
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
self._update_class_dependency(name, k.value.func.value.value)
elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
@ -178,6 +193,10 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
if self.new_name in CONFIG_MAPPING_NAMES:
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
"Config", ""
) # the best source of truth for class names. Could also just use the ones de
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
@ -193,7 +212,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def replace(match):
word = match.group(0)
return self.patterns.get(word, self.default_name)
result = self.patterns.get(word, self.default_name)
return result
return compiled_regex.sub(replace, text)
@ -227,35 +247,102 @@ DOCSTRING_NODE = m.SimpleStatementLine(
)
def SUPER_CALL_NODE(func_name):
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
def get_docstring_indent(docstring):
# Match the first line after the opening triple quotes
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
if match:
# Return the indentation spaces captured
return len(match.group(1))
return 0
def merge_docstrings(original_docstring, updated_docstring):
# indent_level = get_docstring_indent(updated_docstring)
original_level = get_docstring_indent(original_docstring)
if " Args:\n " not in updated_docstring:
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1:
updated_docstring = updated_docstring.lstrip('r"')
new_parts = updated_docstring.split("```")
if len(new_parts) != 3:
raise ValueError("There should only be one example, and it should have opening and closing '```'")
parts[1] = new_parts[1]
updated_docstring = "".join(
[
parts[0].rstrip(" \n") + new_parts[0],
f"\n{original_level*' '}```",
parts[1],
"```",
parts[2],
]
)
elif updated_docstring not in original_docstring:
# add tabulation if we are at the lowest level.
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
updated_docstring = updated_docstring.replace("\n ", "\n ")
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
return updated_docstring
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods):
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments
self.class_name = class_name
def update_body(self, existing_body, new_statements):
"""
Helper method to update the body by removing duplicates before adding new statements.
`existing_body` is the body of the original method, the parent class
`new_statements` are the additional statements
"""
deduplicated_new_body = []
existing_nodes = set()
for node in new_statements:
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(node.body[0].targets[0].target)
self.all_assign_target[target] = node
if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
target = self.python_module.code_for_node(node.body[0].target)
self.deleted_targets[target] = node
continue
for stmt in existing_body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
if target in self.deleted_targets:
logger.warning(f"Deleted the assign for {target}")
continue
if target in self.all_assign_target:
stmt = self.all_assign_target[target]
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
deduplicated_new_body.append(stmt)
existing_nodes.add(comment_less_code)
for node in new_statements:
code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if (
node not in deduplicated_new_body
and "super().__init__" not in comment_less_code
and comment_less_code not in existing_nodes
):
if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
# HACK here to fix the pos_init() that has to be last we kinda do this.
deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:]
existing_nodes.add(comment_less_code)
for stmt in existing_body:
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if comment_less_code not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
continue
deduplicated_new_body.append(stmt)
existing_nodes.add(stmt)
else:
logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}")
return deduplicated_new_body
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
@ -263,26 +350,37 @@ class SuperTransformer(cst.CSTTransformer):
to super().func_name() with the source code of the parent class' `func_name`.
It keeps everything that is defined before `super().func_name()`.
"""
new_body = []
self.has_docstring = False
for expr in node.body:
self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE)
parent_has_docstring = False
if func_name in self.original_methods:
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
new_body = []
has_super_call = False
for idx, expr in enumerate(node.body):
if m.matches(
expr,
m.SimpleStatementLine(
body=[
m.Return(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
| m.Expr(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
]
body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
),
):
if idx != 0 and func_name == "__init__":
raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
has_super_call = True
elif m.matches(expr, DOCSTRING_NODE):
self.has_docstring = True
if parent_has_docstring: # actually here we ought to de-duplicate?
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
updated_docstring = expr.body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
else:
new_node = [expr]
new_body.extend(new_node)
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
new_body.append(expr)
if not self.has_docstring and parent_has_docstring:
new_body = [self.original_methods[func_name].body.body[0]] + new_body
return node.with_changes(body=new_body)
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
@ -330,14 +428,22 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
| ```
"""
original_node = class_finder.classes[class_name]
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
original_methods = {
f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
for f in original_node.body.body
}
updated_methods = {
f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
for f in updated_node.body.body
}
end_meth = []
assign_targets = {}
docstring_node = []
# Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
for func in original_node.body.body:
name = func.name.value if hasattr(func, "name") else func
if name in updated_methods and updated_methods[name] is not None:
name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
new_params = updated_methods[name].params
# Replace the method in the replacement class, preserving decorators
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
@ -348,22 +454,61 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
func = func.with_changes(body=updated_methods[name].body, params=new_params)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = class_finder.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
elif m.matches(func, DOCSTRING_NODE):
docstring_node = [func]
else:
end_meth.append(func)
# Port new methods that are defined only in diff-file and append at the end
for name, func in updated_methods.items():
# Port new methods that are defined only in modular-file and append at the end
for func in updated_node.body.body:
name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
# Extract the original docstring
updated_docstring = func.body[0].value.value
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
end_meth.append(func)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
# TODO we only use single assign might cause issues
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = class_finder.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
end_meth = docstring_node + list(assign_targets.values()) + end_meth
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, class_name)
)
new_replacement_body = new_replacement_class.body[0].body # get the indented block
return original_node.with_changes(body=new_replacement_body)
class DiffConverterTransformer(CSTTransformer):
TYPE_TO_FILE_TYPE = {
"Config": "configuration",
"Tokenizer": "tokenization",
"Processor": "processor",
"ImageProcessor": "image_processing",
"FeatureExtractor": "feature_extractor",
}
class ModularConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
@ -378,11 +523,21 @@ class DiffConverterTransformer(CSTTransformer):
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = {} # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.all_imports = [] # just stores all of the imports
self.all_safe_imports = [] # stores the import under simple statements
self.global_scope_index = 0
# fmt: on
self.files = { # mapping for different component bodies
"modeling": {},
"configuration": {},
"tokenization": {},
"processing": {},
"image_processing": {},
"feature_extractor": {},
}
self.match_patterns = "|".join(self.files.keys())
self.all_functions = {}
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
@ -393,7 +548,7 @@ class DiffConverterTransformer(CSTTransformer):
import_statement = self.python_module.code_for_node(node.module)
if m.matches(node.module, m.Attribute()):
for imported_ in node.names:
_import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
_import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement)
if _import:
source = _import.groups()[0]
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
@ -401,44 +556,38 @@ class DiffConverterTransformer(CSTTransformer):
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
)
if import_statement not in self.transformers_imports:
if "models" not in import_statement:
import_statement = "models." + import_statement
if "transformers" not in import_statement:
import_statement = "transformers." + import_statement
source_code = get_module_source_from_name(import_statement)
tree = cst.parse_module(source_code)
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.global_scope_index += 100
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
return node
if m.matches(node.module, m.Name()):
if "transformers" == import_statement:
raise ValueError(
f"You are importing from {import_statement} directly using global imports. Import from the correct local path"
)
def leave_SimpleStatementLine(self, original_node, updated_node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
if parent_node not in self.all_imports:
if updated_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
if re.search(
rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement
): # OR MATCH ..llama.modeling_llama
return cst.RemoveFromParent()
if parent_node not in self.all_imports:
if updated_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
self.global_scope_index += 100
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
# TODO This only works for single target assigns!
node_name = updated_node.body[0].targets[0].target.value
else:
node_name = self.python_module.code_for_node(updated_node.body[0])
self.new_body[node_name] = {
"insert_idx": self.global_scope_index,
"node": updated_node,
}
self.config_body = [updated_node]
return updated_node
def leave_ClassDef(self, original_node, updated_node):
@ -454,6 +603,7 @@ class DiffConverterTransformer(CSTTransformer):
"""
class_name = original_node.name.value
bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
all_bases = [k.value.value for k in original_node.bases]
self.global_scope_index += 100
for super_class in bases:
if super_class not in self.imported_mapping:
@ -469,7 +619,7 @@ class DiffConverterTransformer(CSTTransformer):
raise ValueError(
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
)
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
visited_module = self.visited_module
if super_file_name not in visited_module: # only extract classes once
class_finder = find_classes_in_file(
@ -490,22 +640,47 @@ class DiffConverterTransformer(CSTTransformer):
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index
file_to_update = self.files[file_type]
is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n"
for dependency, _ in list_dependencies:
# we can write to the correct body, using the source of the parent class
node = class_finder.global_nodes.get(dependency, None)
if node is not None and "Config" not in class_name:
if dependency not in self.new_body:
if node is not None:
if dependency not in file_to_update:
start_insert_idx -= 1
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
elif dependency not in self.inserted_deps:
# make sure the node is written after its dependencies
start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
if (
dependency in file_to_update.keys()
and dependency in class_finder.first_lvl_dependency_mapping[class_name]
):
# If dependency is defined, but not used, raise error
calls = m.findall(original_node, m.Call(func=m.Name(dependency)))
if not calls and not is_empty_node and dependency not in all_bases:
raise ValueError(
f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used
when you define `{class_name}`, as it is one of it's direct dependencies. Make sure
you use it in the `__init__` function."""
)
self.inserted_deps.append(dependency)
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
if "Config" in class_name:
self.config_body += [updated_node]
else:
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
raise ValueError(
f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
)
# Now, if a class was defined without parents, we look for the name
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
match = re.search(rf"({match_pattern})$", class_name)
if match:
key = TYPE_TO_FILE_TYPE[match.group(1)]
self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
else:
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
def leave_If(self, original_node, node):
@ -513,66 +688,69 @@ class DiffConverterTransformer(CSTTransformer):
if m.matches(parent_node, m.Module()):
full_statement = self.python_module.code_for_node(original_node.test)
if re.search(r"[\s\S]*is_.*available", full_statement):
self.all_imports.append(node)
self.all_safe_imports.append(node)
elif full_statement not in self.new_body:
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
return node
def leave_Module(self, original_node: cst.Assign, node):
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
dependency_imports = {}
config_imports = []
for visiter in self.visited_module.values():
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
dependency_imports = {file_type: imports.copy() for file_type in self.files}
for super_file_name, visiter in self.visited_module.items():
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
dependency_imports[file_type].update(
{self.python_module.code_for_node(k): k for k in visiter.imports.values()}
)
# manually clean up if it's importing a config from configuration file (ruff doesn't do that)
config_imports = []
for i in list(dependency_imports.values()):
if (
hasattr(i.body[0], "module")
and isinstance(i.body[0].module, cst.Name)
and f"configuration_{self.model_name}" in i.body[0].module.value
):
pass
else:
config_imports.append(i)
if hasattr(self, "config_body"):
self.config_body = list(imports.values()) + config_imports + self.config_body
dependency_imports.update(imports)
new_body = list(dependency_imports.values())
if len(self.new_body.keys()) > 0:
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
else:
new_body = []
return node.with_changes(body=[*new_body])
for file, body in self.files.items():
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
if len(new_body) > 0:
if file in dependency_imports.keys():
new_body = list(dependency_imports[file].values()) + new_body
self.files[file] = cst.Module(body=[*new_body], header=node.header)
return node
def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}
if pattern is not None:
model_name = pattern.groups()[0]
# Parse the Python file
with open(diff_file, "r") as file:
with open(modular_file, "r") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = run_ruff(new_mod.code, True)
cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name)
wrapper.visit(cst_transformers)
for file, node in cst_transformers.files.items():
if node != {}:
ruffed_code = run_ruff(AUTO_GENERATED_MESSAGE + node.code, True)
formatted_code = run_ruff(ruffed_code, False)
if len(formatted_code.strip()) > 0:
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
output[file] = [formatted_code, ruffed_code]
return output
else:
print(f"modular pattern not found in {modular_file}, exiting")
return {}
if hasattr(cst_transformers, "config_body"):
config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = run_ruff(config_module.code, True)
formatted_code = run_ruff(ruffed_code, False)
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
# TODO optimize by re-using the class_finder
return cst_transformers
def save_modeling_file(modular_file, converted_file):
for file_type in converted_file.keys():
non_comment_lines = len(
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
f.write(converted_file[file_type][0])
else:
non_comment_lines = len(
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
logger.warning("The modeling code contains errors, it's written without formatting")
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
f.write(converted_file[file_type][1])
if __name__ == "__main__":
@ -581,22 +759,24 @@ if __name__ == "__main__":
"--files_to_parse",
default=["all"],
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
for file_name in find_priority_list(args.files_to_parse):
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name, args.old_model_name, args.new_model_name)
converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
converter = save_modeling_file(file_name, converted_files)