mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Modular transformers
: modularity and inheritance for new model additions (#33248)
* update exampel * update * push the converted diff files for testing and ci * correct one example * fix class attributes and docstring * nits * oups * fixed config! * update * nitd * class attributes are not matched against the other, this is missing * fixed overwriting self.xxx now onto the attributes I think * partial fix, now order with docstring * fix docstring order? * more fixes * update * fix missing docstrings! * examples don't all work yet * fixup * nit * updated * hick * update * delete * update * update * update * fix * all default * no local import * fix more diff * some fix related to "safe imports" * push fixed * add helper! * style * add a check * all by default * add the * update * FINALLY! * nit * fix config dependencies * man that is it * fix fix * update diffs * fix the last issue * re-default to all * alll the fixes * nice * fix properties vs setter * fixup * updates * update dependencies * make sure to install what needs to be installed * fixup * quick fix for now * fix! * fixup * update * update * updates * whitespaces * nit * fix * simplify everything, and make it file agnostic (should work for image processors) * style * finish fixing all import issues * fixup * empty modeling should not be written! * Add logic to find who depends on what * update * cleanup * update * update gemma to support positions * some small nits * this is the correct docstring for gemma2 * fix merging of docstrings * update * fixup * update * take doc into account * styling * update * fix hidden activation * more fixes * final fixes! * fixup * fixup instruct blip video * update * fix bugs * align gemma2 with the rest as well * updats * revert * update * more reversiom * grind * more * arf * update * order will matter * finish del stuff * update * rename to modular * fixup * nits * update makefile * fixup * update order of the checks! * fix * fix docstring that has a call inside * fiix conversion check * style * add some initial documentation * update * update doc * some fixup * updates * yups * Mostly todo gimme a minut * update * fixup * revert some stuff * Review docs for the modular transformers (#33472) Docs * good update * fixup * mmm current updates lead to this code * okay, this fixes it * cool * fixes * update * nit * updates * nits * fix doc * update * revert bad changes * update * updates * proper update * update * update? * up * update * cool * nits * nits * bon bon * fix * ? * minimise changes * update * update * update * updates? * fixed gemma2 * kind of a hack * nits * update * remove `diffs` in favor of `modular` * fix make fix copies --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent
75b7485cc7
commit
317e069ee7
@ -137,7 +137,7 @@ jobs:
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: uv pip install -e .
|
||||
- run: uv pip install -e ".[quality]"
|
||||
- run:
|
||||
name: Show installed libraries and their versions
|
||||
command: pip freeze | tee installed.txt
|
||||
@ -162,13 +162,14 @@ jobs:
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: uv pip install -e .
|
||||
- run: uv pip install -e ".[quality]"
|
||||
- run:
|
||||
name: Show installed libraries and their versions
|
||||
command: pip freeze | tee installed.txt
|
||||
- store_artifacts:
|
||||
path: ~/transformers/installed.txt
|
||||
- run: python utils/check_copies.py
|
||||
- run: python utils/check_modular_conversion.py
|
||||
- run: python utils/check_table.py
|
||||
- run: python utils/check_dummies.py
|
||||
- run: python utils/check_repo.py
|
||||
|
2
Makefile
2
Makefile
@ -36,6 +36,7 @@ autogenerate_code: deps_table_update
|
||||
|
||||
repo-consistency:
|
||||
python utils/check_copies.py
|
||||
python utils/check_modular_conversion.py
|
||||
python utils/check_table.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
@ -80,6 +81,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
|
||||
|
||||
fix-copies:
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
python utils/check_modular_conversion.py --fix_and_overwrite
|
||||
python utils/check_table.py --fix_and_overwrite
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
python utils/check_doctest_list.py --fix_and_overwrite
|
||||
|
@ -5,6 +5,8 @@
|
||||
title: Quick tour
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: add_new_model
|
||||
title: Adding a new model to `transformers`
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: pipeline_tutorial
|
||||
@ -149,6 +151,8 @@
|
||||
title: Interoperability with GGUF files
|
||||
- local: tiktoken
|
||||
title: Interoperability with TikToken files
|
||||
- local: modular_transformers
|
||||
title: Modularity in `transformers`
|
||||
title: Developer guides
|
||||
- sections:
|
||||
- local: quantization/overview
|
||||
|
121
docs/source/en/modular_transformers.md
Normal file
121
docs/source/en/modular_transformers.md
Normal file
@ -0,0 +1,121 @@
|
||||
# Modular transformers
|
||||
|
||||
`transformers` is an opinionated framework; our philosophy is defined in the following [conceptual guide](./philosophy).
|
||||
|
||||
The core of that philosophy is exemplified by the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy)
|
||||
aspect of the library. This component's downside is that it limits the inheritance and importability of components from
|
||||
files to others in the toolkit.
|
||||
|
||||
As a result, model components tend to be repeated across many files. There are as many attention layers defined
|
||||
in `transformers` as there are models, and a significant number of those are identical to each other.
|
||||
The unfortunate consequence is that independent implementations tend to diverge as fixes and changes get applied
|
||||
to specific parts of the code.
|
||||
|
||||
In order to balance this issue, we introduced the concept of "copies" across the library. By adding a comment indicating
|
||||
that code is a copy of another, we can enforce through CI and local commands that copies do not diverge. However,
|
||||
while the complexity is low, this is often quite tedious to do.
|
||||
|
||||
And, finally, this contributes to adding a significant overhead to contributing models which we would like to remove.
|
||||
This approach often requires model contributions to add modeling code (~1k lines), processor (~500 lines), tests, docs,
|
||||
etc. Model contribution PRs rarely add less than 3-5k lines of code, with much of this code being boilerplate.
|
||||
|
||||
This raises the bar for contributions, and with Modular Transformers, we're aiming to lower the bar to a much more
|
||||
acceptable point.
|
||||
|
||||
## What is it?
|
||||
|
||||
Modular Transformers introduces the concept of a "modular" file to a model folder. This modular file accepts code
|
||||
that isn't typically accepted in modeling/processing files, as it allows importing from neighbouring models as well
|
||||
as inheritance from classes to others.
|
||||
|
||||
This modular file defines models, processors, and the configuration class that would otherwise be defined in their
|
||||
respective modules.
|
||||
|
||||
Finally, this feature introduces a new `linter` which will "unravel" the modular file into the "single model, single
|
||||
file" directory structure. These files will get auto-generated every time the script is run; reducing the required
|
||||
contributions to the modular file, and therefore only to the changes between the contributed model and others.
|
||||
|
||||
Model users will end up importing and using the single-file interface, so no change is expected here. Doing this, we
|
||||
hope to combine the best of both worlds: enabling simple contributions while sticking to our philosophy.
|
||||
|
||||
This is therefore a replacement for the `# Copied from` markers, and previously contributed models can be expected to
|
||||
be moved to the new Modular Transformers format in the coming months.
|
||||
|
||||
### Details
|
||||
|
||||
The "linter", which unravels the inheritance and creates all single-files from the modular file, will flatten the
|
||||
inheritance while trying to be invisible to Python users. At this time, the linter flattens a **single** level of
|
||||
inheritance.
|
||||
|
||||
For example:
|
||||
- If a configuration class inherits from another and adds/deletes an argument, the generated file will either directly
|
||||
reference it (in case of addition) or completely remove it (in case of deletion).
|
||||
- If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically
|
||||
inferred. All submodules will be automatically inferred from the superclass.
|
||||
|
||||
You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
|
||||
file, and the corresponding files will be created for you.
|
||||
|
||||
### Enforcement
|
||||
|
||||
[TODO] We are introducing a new test, that makes sure the generated content matches what is present in the `modular_xxxx.py`
|
||||
|
||||
### Examples
|
||||
|
||||
Here is a quick example with BERT and RoBERTa. The two models are intimately related: their modeling implementation
|
||||
differs solely by a change in the embedding layer.
|
||||
|
||||
Instead of redefining the model entirely, here is what the `modular_roberta.py` file looks like for the modeling &
|
||||
configuration classes (for the sake of the example, the tokenizer is ignored at this time as very different).
|
||||
|
||||
```python
|
||||
from torch import nn
|
||||
from ..bert.configuration_bert import BertConfig
|
||||
from ..bert.modeling_bert import (
|
||||
BertModel,
|
||||
BertEmbeddings,
|
||||
BertForMaskedLM
|
||||
)
|
||||
|
||||
# The RoBERTa config is identical to BERT's config
|
||||
class RobertaConfig(BertConfig):
|
||||
model_type = 'roberta'
|
||||
|
||||
# We redefine the embeddings here to highlight the padding ID difference, and we redefine the position embeddings
|
||||
class RobertaEmbeddings(BertEmbeddings):
|
||||
def __init__(self, config):
|
||||
super().__init__(config())
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
||||
)
|
||||
|
||||
# The RoBERTa model is identical to the BERT model, except for the embedding layer.
|
||||
# We redefine the embeddings above, so here there is no need to do additional work
|
||||
class RobertaModel(BertModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
|
||||
|
||||
# The heads now only need to redefine the model inside to the correct `RobertaModel`
|
||||
class RobertaForMaskedLM(BertForMaskedLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = RobertaModel(config)
|
||||
```
|
||||
|
||||
Note that if you do not use the dependency that you defined, you will have the following error:
|
||||
|
||||
```bash
|
||||
ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used
|
||||
when you define `BertModel`, as it is one of it's direct dependencies. Make sure
|
||||
you use it in the `__init__` function.
|
||||
```
|
||||
|
||||
Additionally, you may find a list of examples here:
|
||||
|
||||
## What it is not
|
||||
|
||||
It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
|
@ -1,20 +0,0 @@
|
||||
# Using the `diff_converter` linter
|
||||
|
||||
`pip install libcst` is a must!
|
||||
|
||||
# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs
|
||||
|
||||
The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`.
|
||||
|
||||
Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage.
|
||||
|
||||
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
|
||||
|
||||
## How it works
|
||||
We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies.
|
||||
|
||||
The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file.
|
||||
We use ruff to automatically remove the potential duplicate imports.
|
||||
|
||||
## Why we use libcst instead of the native AST?
|
||||
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`
|
20
examples/modular-transformers/README.md
Normal file
20
examples/modular-transformers/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Using the `modular_converter` linter
|
||||
|
||||
`pip install libcst` is a must!
|
||||
|
||||
# `sh examples/modular-transformers/convert_examples.sh` to get the converted outputs
|
||||
|
||||
The modular converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular file like `modular_gemma.py` into a `single model single file`.
|
||||
|
||||
Examples of possible usage are available in the `examples/modular-transformers`, or `modular_gemma` for a full model usage.
|
||||
|
||||
`python utils/modular_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/modular-transformers/modular_my_new_model2.py"`
|
||||
|
||||
## How it works
|
||||
We use the `libcst` parser to produce an AST representation of the `modular_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the modularerence dependencies.
|
||||
|
||||
The code from the `modular` file and the class dependency mapping are "merged" to produce the single model single file.
|
||||
We use ruff to automatically remove the potential duplicate imports.
|
||||
|
||||
## Why we use libcst instead of the native AST?
|
||||
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`
|
196
examples/modular-transformers/configuration_my_new_model.py
Normal file
196
examples/modular-transformers/configuration_my_new_model.py
Normal file
@ -0,0 +1,196 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class MyNewModelConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MyNewModel-7B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MyNewModelModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
||||
MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
||||
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
||||
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'my_new_model3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||
head_dim (`int`, *optional*):
|
||||
The attention head dimension. If None, it will default to hidden_size // num_heads
|
||||
new_param (`int`, *optional*, defaults to `False`):
|
||||
A fun new parameter
|
||||
|
||||
```python
|
||||
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
||||
|
||||
>>> # Initializing a MyNewModel my_new_model-7b style configuration
|
||||
>>> configuration = MyNewModelConfig()
|
||||
|
||||
>>> # Initializing a model from the my_new_model-7b style configuration
|
||||
>>> model = MyNewModelModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "my_new_model"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=True,
|
||||
head_dim=None,
|
||||
new_param=0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.mlp_bias = mlp_bias
|
||||
self.new_param = new_param
|
97
examples/modular-transformers/configuration_my_new_model2.py
Normal file
97
examples/modular-transformers/configuration_my_new_model2.py
Normal file
@ -0,0 +1,97 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class MyNewModel2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "my_new_model2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
head_dim=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
134
examples/modular-transformers/configuration_new_model.py
Normal file
134
examples/modular-transformers/configuration_new_model.py
Normal file
@ -0,0 +1,134 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Example where we only want to overwrite the defaults of an init
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class NewModelConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`NewModelModel`]. It is used to instantiate an NewModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the NewModel-7B.
|
||||
e.g. [google/new_model-7b](https://huggingface.co/google/new_model-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the NewModel model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`NewModelModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||
hidden_activation (`str` or `function`, *optional*):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
```python
|
||||
>>> from transformers import NewModelModel, NewModelConfig
|
||||
>>> # Initializing a NewModel new_model-7b style configuration
|
||||
>>> configuration = NewModelConfig()
|
||||
>>> # Initializing a model from the new_model-7b style configuration
|
||||
>>> model = NewModelModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "new_model"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256030,
|
||||
hidden_size=64,
|
||||
intermediate_size=90,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation=None,
|
||||
max_position_embeddings=1500,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_activation = hidden_activation
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_heads(self):
|
||||
return self.num_attention_heads
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Iterate over each file in the current directory
|
||||
for file in examples/diff-conversion/diff_*; do
|
||||
for file in examples/modular-transformers/modular_*; do
|
||||
# Check if it's a regular file
|
||||
if [ -f "$file" ]; then
|
||||
# Call the Python script with the file name as an argument
|
1053
examples/modular-transformers/modeling_dummy.py
Normal file
1053
examples/modular-transformers/modeling_dummy.py
Normal file
File diff suppressed because it is too large
Load Diff
1038
examples/modular-transformers/modeling_dummy_bert.py
Normal file
1038
examples/modular-transformers/modeling_dummy_bert.py
Normal file
File diff suppressed because it is too large
Load Diff
1059
examples/modular-transformers/modeling_my_new_model2.py
Normal file
1059
examples/modular-transformers/modeling_my_new_model2.py
Normal file
File diff suppressed because it is too large
Load Diff
953
examples/modular-transformers/modeling_super.py
Normal file
953
examples/modular-transformers/modeling_super.py
Normal file
@ -0,0 +1,953 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
from .configuration_super import SuperConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class SuperRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
SuperRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class SuperRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[SuperConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.45"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class SuperMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
if self.config.pretraining_tp > 1:
|
||||
slice = self.intermediate_size // self.config.pretraining_tp
|
||||
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
||||
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
||||
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
||||
|
||||
gate_proj = torch.cat(
|
||||
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
||||
)
|
||||
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
||||
|
||||
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
||||
down_proj = [
|
||||
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
down_proj = sum(down_proj)
|
||||
else:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
return down_proj
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class SuperAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
|
||||
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
|
||||
self.rotary_emb = SuperRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SuperFlashAttention2(SuperAttention):
    """
    Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (SuperRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SuperSdpaAttention(SuperAttention):
    """
    Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from SuperAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


SUPER_ATTENTION_CLASSES = {
    "eager": SuperAttention,
    "flash_attention_2": SuperFlashAttention2,
    "sdpa": SuperSdpaAttention,
}


class SuperDecoderLayer(nn.Module):
    def __init__(self, config: SuperConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = SuperMLP(config)
        self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


SUPER_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`SuperConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
    "The bare Super Model outputting raw hidden-states without any specific head on top.",
    SUPER_START_DOCSTRING,
)
class SuperPreTrainedModel(PreTrainedModel):
    config_class = SuperConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuperDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


SUPER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||
|
||||
Two formats are allowed:
|
||||
- a [`~cache_utils.Cache`] instance;
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||
cache format.
|
||||
|
||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||
legacy cache format will be returned.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||
of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
    "The bare Super Model outputting raw hidden-states without any specific head on top.",
    SUPER_START_DOCSTRING,
)
class SuperModel(SuperPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuperDecoderLayer`]

    Args:
        config: SuperConfig
    """

    def __init__(self, config: SuperConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SuperDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = SuperRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        out = super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
        out.logits *= 2**4
        return out

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
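For intuition, here is a minimal, self-contained sketch of the 2D-to-4D expansion that `_prepare_4d_causal_attention_mask_with_cache_position` performs for the method above, leaving out the `cache_position` and `target_length` bookkeeping; `toy_causal_mask` is an illustrative name, not a library function.

import torch

def toy_causal_mask(padding_mask_2d, dtype=torch.float32):
    # Illustrative only: build an additive causal mask of shape (batch, 1, seq, seq)
    # from a 2D padding mask of shape (batch, seq).
    bsz, seq_len = padding_mask_2d.shape
    min_dtype = torch.finfo(dtype).min
    causal = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
    causal = torch.triu(causal, diagonal=1)  # only future positions stay masked
    causal = causal[None, None, :, :].expand(bsz, 1, -1, -1).clone()
    padding = padding_mask_2d[:, None, None, :] == 0  # 0 means "masked" in the 2D mask
    return causal.masked_fill(padding, min_dtype)

print(toy_causal_mask(torch.tensor([[1, 1, 1, 0]])).shape)  # torch.Size([1, 1, 4, 4])
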
@@ -3,10 +3,11 @@ from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from ...cache_utils import Cache


def _pre_process_input(input_ids):
    print(log(input_ids))
27  examples/modular-transformers/modular_dummy_bert.py  Normal file
@@ -0,0 +1,27 @@
from typing import List, Optional, Tuple, Union

import torch

from transformers.models.bert.modeling_bert import BertModel

from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class DummyBertModel(BertModel):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        return super().forward(input_ids)
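The point of `DummyBertModel` above is that a modular file only spells out what changes: the override keeps the parent signature for API compatibility but delegates with a reduced argument set. A hedged sketch of the same delegation pattern on plain classes (the names here are hypothetical, not part of the example package):

class Base:
    def forward(self, input_ids, attention_mask=None, return_dict=None):
        return {"input_ids": input_ids, "attention_mask": attention_mask}

class Derived(Base):
    # Full signature is kept, but only `input_ids` is forwarded,
    # exactly like `DummyBertModel.forward` above.
    def forward(self, input_ids, attention_mask=None, return_dict=None):
        return super().forward(input_ids)

print(Derived().forward([1, 2, 3], attention_mask=[1, 1, 1]))  # attention_mask is dropped
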
@@ -5,10 +5,11 @@ from transformers.models.llama.configuration_llama import LlamaConfig
# here there is no `ARG` so we are gonna take parent doc
class MyNewModelConfig(LlamaConfig):
    r"""
    mlp_bias (`bool`, *optional*, defaults to `False`)
    new_param (`int`, *optional*, defaults to `False`):
        A fun new parameter
    """

    def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
        super().__init__(self, **super_kwargs)
        self.mlp_bias = mlp_bias
        self.new_param = new_param
        super().__init__(self, **super_kwargs)
@@ -26,5 +26,10 @@ class NewModelConfig(GemmaConfig):
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(self)
        super().__init__(self, **kwargs)

    @property
    def num_heads(self):
        return self.num_attention_heads
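The `num_heads` property added above is a read-only alias for `num_attention_heads`. A minimal sketch of the same pattern on a standalone class (the class name is hypothetical):

class TinyConfig:
    def __init__(self, num_attention_heads=8):
        self.num_attention_heads = num_attention_heads

    @property
    def num_heads(self):
        # read-only alias, mirroring `NewModelConfig.num_heads` above
        return self.num_attention_heads

cfg = TinyConfig(num_attention_heads=16)
print(cfg.num_heads)  # 16
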
20  examples/modular-transformers/modular_roberta.py  Normal file
@@ -0,0 +1,20 @@
import torch.nn as nn

from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel


class RobertaEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, config.pad_token_id
        )


class RobertaModel(BertModel):
    def __init__(self, config):
        super().__init__(self, config)
        # Error out here. Why? Because `RobertaEmbeddings` is defined but not used.
        # no, because it's defined, and RobertaModel should use RobertaEmbedding
        # here if initialized that way it won't use the new embedding.
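The `RobertaEmbeddings` override above re-creates `position_embeddings` with `padding_idx=config.pad_token_id`; when a `padding_idx` is given, `nn.Embedding` keeps a zero vector at that index and does not update it during training. A small standalone illustration (the value 1 is just an illustrative pad id):

import torch
import torch.nn as nn

pad_token_id = 1  # illustrative pad id
emb = nn.Embedding(10, 4, padding_idx=pad_token_id)
print(emb.weight[pad_token_id])               # all zeros
out = emb(torch.tensor([0, pad_token_id, 2]))
print(out[1].abs().sum())                     # stays zero for the padding position
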
@@ -2,10 +2,11 @@ from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from ...cache_utils import Cache


# example where we need some deps and some functions
class SuperModel(LlamaModel):
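This `modular_super.py` stub is the source that gets expanded into the generated `SuperModel` shown earlier in this diff (including the `out.logits *= 2**4` override). A sketch consistent with that generated code, not a verbatim copy of the example file:

from transformers.models.llama.modeling_llama import LlamaModel

class SuperModel(LlamaModel):
    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        out.logits *= 2**4  # the only behavioural change; everything else is inherited
        return out
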
4  setup.py
@@ -192,6 +192,8 @@ _deps = [
"urllib3<2.0.0",
|
||||
"uvicorn",
|
||||
"pytest-rich",
|
||||
"libcst",
|
||||
"rich",
|
||||
]
|
||||
|
||||
|
||||
@@ -345,7 +347,7 @@ extras["testing"] = (

extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
extras["ruff"] = deps_list("ruff")
extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3")
extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3", "libcst", "rich")

extras["all"] = (
    extras["tf"]
@@ -97,4 +97,6 @@ deps = {
    "urllib3": "urllib3<2.0.0",
    "uvicorn": "uvicorn",
    "pytest-rich": "pytest-rich",
    "libcst": "libcst",
    "rich": "rich",
}
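The `deps` table above is what the `quality` extra is resolved against; a minimal sketch of that lookup, assuming the usual `deps_list` helper defined in `setup.py` (recalled from memory, so treat the exact signature as an approximation):

deps = {"libcst": "libcst", "rich": "rich", "ruff": "ruff"}  # excerpt of the table above

def deps_list(*pkgs):
    # resolve short package names to pinned requirement strings
    return [deps[pkg] for pkg in pkgs]

print(deps_list("libcst", "rich"))  # ['libcst', 'rich'] -> feeds extras["quality"]
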
@@ -1,8 +1,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@@ -21,7 +21,7 @@
# limitations under the License.


from transformers import PretrainedConfig
from ...configuration_utils import PretrainedConfig


class GemmaConfig(PretrainedConfig):
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
@ -39,7 +39,6 @@ from ...modeling_outputs import (
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@ -51,63 +50,6 @@ from ...utils import (
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@ -128,7 +70,7 @@ class GemmaRMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
|
||||
|
||||
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GemmaRotaryEmbedding(nn.Module):
|
||||
@ -159,30 +101,6 @@ class GemmaRotaryEmbedding(nn.Module):
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
if config.hidden_activation is None:
|
||||
logger.warning_once(
|
||||
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
|
||||
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
|
||||
"`config.hidden_activation` if you want to override this behaviour.\n"
|
||||
"See https://github.com/huggingface/transformers/pull/29402 for more details."
|
||||
)
|
||||
config.hidden_activation = "gelu_pytorch_tanh"
|
||||
hidden_activation = config.hidden_activation
|
||||
self.act_fn = ACT2FN[hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
|
||||
"""GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
@ -212,6 +130,30 @@ class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
|
||||
return cos, sin
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
if config.hidden_activation is None:
|
||||
logger.warning_once(
|
||||
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
|
||||
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
|
||||
"`config.hidden_activation` if you want to override this behaviour.\n"
|
||||
"See https://github.com/huggingface/transformers/pull/29402 for more details."
|
||||
)
|
||||
config.hidden_activation = "gelu_pytorch_tanh"
|
||||
hidden_activation = config.hidden_activation
|
||||
self.act_fn = ACT2FN[hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@ -358,6 +300,94 @@ class GemmaAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class GemmaSdpaAttention(GemmaAttention):
|
||||
"""
|
||||
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from GemmaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class GemmaFlashAttention2(GemmaAttention):
|
||||
"""
|
||||
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
|
||||
@ -458,7 +488,6 @@ class GemmaFlashAttention2(GemmaAttention):
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
@ -468,92 +497,57 @@ class GemmaFlashAttention2(GemmaAttention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class GemmaSdpaAttention(GemmaAttention):
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
|
||||
# Adapted from GemmaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
return causal_mask
|
||||
|
||||
|
||||
GEMMA_ATTENTION_CLASSES = {
|
||||
@ -567,9 +561,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GemmaMLP(config)
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -830,9 +822,9 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
return_legacy_cache = False # noqa: F841
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
return_legacy_cache = True # noqa: F841
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
@ -975,6 +967,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
@ -1149,6 +1142,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
@ -1230,7 +1224,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
|
@ -21,8 +21,15 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
@ -32,14 +39,6 @@ from transformers.models.llama.modeling_llama import (
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -216,6 +215,35 @@ class GemmaRotaryEmbedding(nn.Module):
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
|
||||
"""GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def forward(self, x, position_ids):
|
||||
# difference to the original RoPE: a scaling factor is aplied to the position ids
|
||||
position_ids = position_ids.float() / self.scaling_factor
|
||||
cos, sin = super().forward(x, position_ids)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
|
||||
"""GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def forward(self, x, position_ids):
|
||||
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
|
||||
|
||||
cos, sin = super().forward(x, position_ids)
|
||||
return cos, sin
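A worked illustration of the dynamic-NTK recomputation above, with assumed numbers (`base=10000`, `dim=256`, `scaling_factor=2`, original context 8192): once the sequence exceeds the original length, the RoPE base is inflated so the lowest frequencies stretch over the longer window.

```python
# Hedged numeric sketch of the base rescaling formula used in the forward above.
base, dim = 10000.0, 256
scaling_factor, max_position_embeddings = 2.0, 8192
seq_len = 16384  # twice the original context

new_base = base * (
    (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))
print(round(new_base))  # ~30261: roughly a 3x larger base for a 2x longer sequence
```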
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@ -340,8 +368,95 @@ class GemmaAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting from GemmaAttention?
|
||||
class GemmaFlashAttention2(LlamaFlashAttention2):
|
||||
class GemmaSdpaAttention(GemmaAttention):
|
||||
"""
|
||||
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`GemmaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from GemmaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
|
||||
"""
|
||||
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
@ -427,12 +542,12 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
@ -442,7 +557,95 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
GEMMA_ATTENTION_CLASSES = {
|
||||
"eager": GemmaAttention,
|
||||
"flash_attention_2": GemmaFlashAttention2,
|
||||
"sdpa": GemmaSdpaAttention,
|
||||
}
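The mapping above is what `config._attn_implementation` indexes into when each decoder layer is built. A hedged usage sketch (the checkpoint name is illustrative, not taken from this diff):

```python
from transformers import AutoModelForCausalLM

# Requesting a backend at load time decides which GEMMA_ATTENTION_CLASSES entry is used.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",                     # assumed checkpoint
    attn_implementation="sdpa",            # "eager", "sdpa" or "flash_attention_2"
)
print(model.config._attn_implementation)   # "sdpa" -> GemmaSdpaAttention in every layer
```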
class GemmaDecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__(config)
|
||||
self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.mlp = GemmaMLP(config)
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class GemmaModel(LlamaModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet!
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@ -455,7 +658,7 @@ class GemmaModel(LlamaModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -513,22 +716,72 @@ class GemmaModel(LlamaModel):
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states * normalizer
|
||||
|
||||
return super().forward(
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
use_cache,
|
||||
cache_position,
|
||||
input_ids=None,
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
# Example where we only modify the docstring and call super
|
||||
class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
|
||||
class GemmaForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@ -542,18 +795,9 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
@ -589,10 +833,18 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
if labels is None and not is_torchdynamo_compiling():
|
||||
logger.warning_once(
|
||||
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
|
||||
)
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
# TODO: remove the float() operation in v4.46
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
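A hedged sketch of the `num_logits_to_keep` slicing introduced above (shapes are illustrative): with `num_logits_to_keep=1`, only the last position's logits are materialized, which is all generation needs.

```python
import torch

hidden_states = torch.randn(2, 7, 16)           # (batch, seq_len, hidden_size)
lm_head = torch.nn.Linear(16, 32, bias=False)   # hidden_size -> vocab_size

num_logits_to_keep = 1
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
print(logits.shape)                             # torch.Size([2, 1, 32]) instead of [2, 7, 32]
```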
@ -618,8 +870,14 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
|
||||
|
||||
|
||||
class GemmaForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.post_init()
|
||||
|
||||
|
||||
class GemmaForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.post_init()
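The whole file above is a modular definition: classes inherit from another model's modeling code and override only what differs, and the converter expands them into a standalone `modeling_*.py`. A hedged, hypothetical sketch of the same pattern (`MyNewModel` is not a real model):

```python
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM


class MyNewModelForCausalLM(GemmaForCausalLM):  # hypothetical class, for illustration only
    pass  # everything is inherited; the full modeling file is produced by the converter
```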
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
@ -19,7 +19,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
@ -53,7 +55,8 @@ class Gemma2Config(PretrainedConfig):
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
@ -77,16 +80,17 @@ class Gemma2Config(PretrainedConfig):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma2Model, Gemma2Config
|
||||
>>> # Initializing a Gemma2 gemma2-9b style configuration
|
||||
>>> # Initializing a Gemma2 gemma2-7b style configuration
|
||||
>>> configuration = Gemma2Config()
|
||||
>>> # Initializing a model from the gemma2-9b style configuration
|
||||
>>> # Initializing a model from the gemma2-7b style configuration
|
||||
>>> model = Gemma2Model(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
@ -94,6 +98,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
|
||||
model_type = "gemma2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
cache_implementation = "hybrid"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -116,12 +121,19 @@ class Gemma2Config(PretrainedConfig):
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
@ -130,23 +142,14 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_activation = hidden_activation
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.hidden_activation = hidden_activation
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.cache_implementation = "hybrid"
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
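Beyond the docstring example above, a hedged sketch showing that the reordered arguments keep their defaults, so the Gemma2-specific soft-capping values survive a partial override (override values chosen arbitrarily):

```python
from transformers import Gemma2Config

config = Gemma2Config(sliding_window=2048, query_pre_attn_scalar=144)
print(config.final_logit_softcapping, config.attn_logit_softcapping)  # 30.0 50.0
print(config.cache_implementation)                                    # "hybrid"
```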
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
@ -22,13 +22,14 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@ -39,7 +40,6 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torchdynamo_compiling,
|
||||
@ -49,66 +49,6 @@ from ...utils import (
|
||||
from .configuration_gemma2 import Gemma2Config
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to place the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Gemma2RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@ -129,6 +69,24 @@ class Gemma2RMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
|
||||
|
||||
class Gemma2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@ -191,21 +149,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Gemma2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
@ -253,12 +196,12 @@ class Gemma2Attention(nn.Module):
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
self.rotary_emb = Gemma2RotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -502,9 +445,11 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
@ -515,6 +460,7 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
@ -533,6 +479,59 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to place the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
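A hedged sketch of what the helper above returns for a short prefill step (it imports the module-level function added by this diff; dtype and lengths are illustrative):

```python
import torch
from transformers.models.gemma2.modeling_gemma2 import (
    _prepare_4d_causal_attention_mask_with_cache_position,
)

mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 4, dtype=torch.long),  # no padding
    sequence_length=4,
    target_length=8,                                     # static/hybrid cache length
    dtype=torch.float32,
    device="cpu",
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 4, 8]) -> (batch, 1, query_length, key_value_length)
```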
GEMMA2_ATTENTION_CLASSES = {
|
||||
"eager": Gemma2Attention,
|
||||
"flash_attention_2": Gemma2FlashAttention2,
|
||||
@ -543,19 +542,16 @@ GEMMA2_ATTENTION_CLASSES = {
|
||||
class Gemma2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -567,6 +563,25 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
||||
into the model
|
||||
"""
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
@ -580,6 +595,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -711,13 +727,20 @@ GEMMA2_INPUTS_DOCSTRING = r"""
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`HybridCache`, *optional*):
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||
|
||||
Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
|
||||
cache classes.
|
||||
Two formats are allowed:
|
||||
- a [`~cache_utils.Cache`] instance, see our
|
||||
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||
cache format.
|
||||
|
||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||
legacy cache format will be returned.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
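A hedged usage sketch of the cache flow described above, mirroring the `HybridCache` construction used later in this file (model name and lengths are illustrative; checkpoint access is assumed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import HybridCache

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b", torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
cache = HybridCache(model.config, batch_size=1, max_cache_len=128, device=model.device, dtype=model.dtype)
outputs = model(**inputs, past_key_values=cache, use_cache=True)
print(type(outputs.past_key_values).__name__)  # the same cache object is returned, now filled
```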
@ -812,8 +835,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# Instantiate an empty cache if needed.
|
||||
if use_cache and past_key_values is None:
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
@ -828,6 +850,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
@ -844,6 +867,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@ -880,7 +904,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@ -1009,6 +1032,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
if self.training and self.config._attn_implementation != "eager":
|
||||
logger.warning_once(
|
||||
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
|
||||
@ -1187,10 +1211,10 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
|
@ -13,30 +13,41 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.models.gemma.configuration_gemma import GemmaConfig
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaDecoderLayer,
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaPreTrainedModel,
|
||||
GemmaRMSNorm,
|
||||
_prepare_4d_causal_attention_mask_with_cache_position,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
@ -45,33 +56,230 @@ if is_flash_attn_2_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma2Config(GemmaConfig):
|
||||
cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma2-7B.
|
||||
e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Gemma2Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma2Model, Gemma2Config
|
||||
>>> # Initializing a Gemma2 gemma2-7b style configuration
|
||||
>>> configuration = Gemma2Config()
|
||||
>>> # Initializing a model from the gemma2-7b style configuration
|
||||
>>> model = Gemma2Model(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
cache_implementation = "hybrid"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
final_logit_softcapping=30.0,
|
||||
**super_kwargs,
|
||||
attn_logit_softcapping=50.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(self, **super_kwargs)
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_activation = hidden_activation
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.cache_implementation = "hybrid"
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
|
||||
class Gemma2RMSNorm(GemmaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Gemma2Attention(GemmaAttention):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
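A worked illustration (assuming `attn_logit_softcapping=50.0`, the config default) of the tanh soft-capping applied in the forward above: scores are squashed into (-50, 50) while small scores pass through almost unchanged.

```python
import torch

cap = 50.0
scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = torch.tanh(scores / cap) * cap
print(capped)  # ~[-49.97, -9.87, 0.00, 9.87, 49.97]
```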
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
@ -119,9 +327,19 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if attention_mask is not None:
|
||||
seq_len = attention_mask.shape[1]
|
||||
key_states = key_states[:, :, :seq_len]
|
||||
value_states = value_states[:, :, :seq_len]
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
@ -156,7 +374,6 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
@ -166,7 +383,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
is_causal=self.is_causal,
|
||||
sliding_window=self.sliding_window,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
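A worked illustration of the `sliding_window` value threaded through this class: per the `config.sliding_window if not bool(layer_idx % 2) else None` assignment, even-indexed layers attend within a window and odd-indexed layers attend globally (window size assumed to be the config default).

```python
sliding_window = 4096  # assumed config.sliding_window
for layer_idx in range(4):
    window = sliding_window if not bool(layer_idx % 2) else None
    print(layer_idx, window)  # 0 4096 / 1 None / 2 4096 / 3 None
```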
@ -227,7 +446,12 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
@ -269,8 +493,9 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
||||
class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self.is_sliding = bool(layer_idx % 2)
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
@ -286,11 +511,18 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
attention_mask = attention_mask * torch.tril(
|
||||
torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
|
||||
)
|
||||
if cache_position[0] > 0:
|
||||
# Flash-attn is a 2D tensor
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if past_key_value is not None: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
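The non-flash branch in the hunk above pushes keys further than `sliding_window` behind the query down to the dtype minimum so softmax ignores them. A tiny self-contained sketch of that masking step, with toy sizes rather than the real decoder-layer tensors:

import torch

seq_len, sliding_window = 8, 4
attention_mask = torch.zeros(1, 1, seq_len, seq_len)  # additive mask, 0 = visible
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
print(attention_mask[0, 0])  # the far-back lower triangle is masked out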
@ -326,13 +558,38 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
class Gemma2Model(GemmaModel):
|
||||
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
|
||||
_supports_quantized_cache = False
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
|
||||
"""
|
||||
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
|
||||
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
|
||||
"""
|
||||
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
|
||||
|
||||
# if using the default path -> swap sdpa by eager
|
||||
if not hard_check_only and config._attn_implementation == "sdpa":
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
return config
|
||||
|
||||
|
||||
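In practice the `_check_and_enable_sdpa` override above means that, unless an implementation is requested explicitly, Gemma2 falls back from the default "sdpa" to "eager" (SDPA cannot apply the attention logit softcapping). A hedged usage sketch; the checkpoint name is only a placeholder:

from transformers import AutoModelForCausalLM

# Requesting "eager" explicitly (recommended, especially for training) behaves as usual;
# leaving the default would also end up on "eager" because of the override above.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", attn_implementation="eager")
print(model.config._attn_implementation)  # "eager"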
class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
|
||||
def __init__(self, config: Gemma2Config):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -361,8 +618,21 @@ class Gemma2Model(GemmaModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
device=self.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
|
||||
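The block above lazily builds a `HybridCache` when generation is requested without a user-provided cache, then derives `cache_position` from the tokens already stored. A minimal sketch of the same calls outside the model, assuming a Gemma2 config is available (checkpoint name and sizes are placeholders):

import torch
from transformers import AutoConfig
from transformers.cache_utils import HybridCache

config = AutoConfig.from_pretrained("google/gemma-2-2b")
batch_size, seq_len = 1, 16
past_key_values = HybridCache(
    config, batch_size=batch_size, max_cache_len=seq_len, device="cpu", dtype=torch.float32
)
past_seen_tokens = past_key_values.get_seq_length()  # 0 for a fresh cache
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)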
if cache_position is None:
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
@ -437,50 +707,50 @@ class Gemma2Model(GemmaModel):
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
past_key_values: HybridCache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
# Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
# So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
# to cut out the keys/values trailing 0s used in the static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.

|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1]
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
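The hunk above swaps the hand-rolled mask construction for the shared `_prepare_4d_causal_attention_mask_with_cache_position` helper. A simplified re-implementation of what that helper produces, based on the removed code shown above (toy shapes, not the exact library signature):

import torch

def toy_4d_causal_mask(attention_mask_2d, sequence_length, target_length, dtype, cache_position):
    # Expand a 2D padding mask into a 4D additive causal mask of shape
    # (batch, 1, sequence_length, target_length), as the helper above does.
    min_dtype = torch.finfo(dtype).min
    causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    causal = torch.triu(causal, diagonal=1)
    causal *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal = causal[None, None, :, :].expand(attention_mask_2d.shape[0], 1, -1, -1).clone()
    mask_length = attention_mask_2d.shape[-1]
    padding = causal[:, :, :, :mask_length] + attention_mask_2d[:, None, None, :]
    causal[:, :, :, :mask_length] = causal[:, :, :, :mask_length].masked_fill(padding == 0, min_dtype)
    return causal

mask = toy_4d_causal_mask(torch.ones(2, 5), 5, 5, torch.float32, torch.arange(5))
print(mask.shape)  # torch.Size([2, 1, 5, 5])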
class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
||||
class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Gemma2Model(config)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -488,18 +758,9 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
@ -514,12 +775,17 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
if self.training and self.config._attn_implementation != "eager":
|
||||
logger.warning_once(
|
||||
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
|
||||
f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -535,15 +801,23 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
if labels is None and not is_torchdynamo_compiling():
|
||||
logger.warning_once(
|
||||
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
|
||||
)
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
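Numerically, the final logit softcapping above is a smooth tanh squash that bounds logits to (-cap, cap) while leaving small values almost untouched. A tiny sketch of the same arithmetic:

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same arithmetic as the block above: divide, tanh, rescale.
    return torch.tanh(logits / cap) * cap

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(soft_cap(x, cap=30.0))  # large values saturate near ±30, small values barely change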
|
||||
# TODO: remove the float() operation in v4.46
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
@ -567,10 +841,94 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||
# which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
dtype = self.lm_head.weight.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_length(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if num_logits_to_keep is not None:
|
||||
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
||||
pass
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Gemma2Model(config)
|
||||
self.post_init()
|
||||
|
||||
|
||||
class Gemma2ForTokenClassification(GemmaForTokenClassification):
|
||||
pass
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Gemma2Model(config)
|
||||
self.post_init()
|
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
@ -24,9 +24,7 @@ from typing import Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
@ -36,8 +34,8 @@ logger = logging.get_logger(__name__)
|
||||
class InstructBlipVideoVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
|
||||
instantiate a Instructblipvideo vision encoder according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration defaults will yield a similar configuration to that of the Instructblipvideo
|
||||
instantiate a InstructBlipVideo vision encoder according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration defaults will yield a similar configuration to that of the InstructBlipVideo
|
||||
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
@ -58,7 +56,7 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
||||
normalization layers.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
@ -137,9 +135,9 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
|
||||
class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
|
||||
instantiate a Instructblipvideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
|
||||
instantiate a InstructBlipVideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
|
||||
model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
the Instructblipvideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
|
||||
the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
|
||||
architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
||||
Read the documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
@ -189,7 +187,7 @@ class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
|
||||
|
||||
>>> # Initializing a Instructblipvideo Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> # Initializing a InstructBlipVideo Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> configuration = InstructBlipVideoQFormerConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||
@ -360,7 +358,7 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a Instructblipvideo vision model, Q-Former and
|
||||
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
|
||||
language model configurations.
|
||||
|
||||
Returns:
|
||||
|
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
@ -19,7 +19,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
@ -354,6 +353,21 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
@ -371,6 +385,71 @@ INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
"""
|
||||
|
||||
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
||||
[`InstructBlipVideoProcessor.__call__`] for details.
|
||||
|
||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||
to serve as text prompt, which the Q-Former model will encode.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||
provided to serve as text prompt, which the language model can continue.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
||||
encoder-decoder language model (like T5) is used.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
|
||||
class InstructBlipVideoEncoder(nn.Module):
|
||||
@ -459,87 +538,6 @@ class InstructBlipVideoEncoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
||||
[`InstructBlipVideoProcessor.__call__`] for details.
|
||||
|
||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||
to serve as text prompt, which the Q-Former model will encode.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||
provided to serve as text prompt, which the language model can continue.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
||||
encoder-decoder language model (like T5) is used.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
|
||||
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
@ -1089,7 +1087,7 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
|
||||
|
||||
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in Instructblipvideo. Slightly modified from BLIP-2 as it also takes the
|
||||
Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
|
||||
instruction as input.
|
||||
"""
|
||||
|
||||
@ -1285,7 +1283,7 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Instructblipvideo Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||
encoder, Querying Transformer (Q-Former) and a language model.
|
||||
|
||||
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
||||
@ -1358,7 +1356,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
hf_device_map = self.hf_device_map
|
||||
|
||||
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
||||
# warn users about unexpected behavior when using multi-GPU + Instructblipvideo + `accelerate`.
|
||||
# warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
|
||||
logger.warning(
|
||||
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
||||
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
||||
@ -1505,7 +1503,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
@ -1584,7 +1581,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
r"""
|
||||
Overrides `generate` function to be able to use the model as a conditional generator.
|
||||
|
||||
Args:
|
||||
|
@ -21,32 +21,18 @@ import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.models.instructblip.configuration_instructblip import (
|
||||
InstructBlipConfig,
|
||||
InstructBlipQFormerConfig,
|
||||
InstructBlipVisionConfig,
|
||||
)
|
||||
from transformers.models.instructblip.modeling_instructblip import (
|
||||
InstructBlipAttention,
|
||||
InstructBlipEncoder,
|
||||
InstructBlipEncoderLayer,
|
||||
InstructBlipForConditionalGeneration,
|
||||
InstructBlipForConditionalGenerationModelOutput,
|
||||
InstructBlipMLP,
|
||||
InstructBlipPreTrainedModel,
|
||||
InstructBlipQFormerAttention,
|
||||
InstructBlipQFormerEmbeddings,
|
||||
InstructBlipQFormerEncoder,
|
||||
InstructBlipQFormerIntermediate,
|
||||
InstructBlipQFormerLayer,
|
||||
InstructBlipQFormerModel,
|
||||
InstructBlipQFormerOutput,
|
||||
InstructBlipQFormerSelfOutput,
|
||||
InstructBlipVisionEmbeddings,
|
||||
InstructBlipVisionModel,
|
||||
)
|
||||
|
||||
from ...generation import GenerationMixin
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -60,8 +46,124 @@ class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoConfig(InstructBlipConfig):
|
||||
pass
|
||||
class InstructBlipVideoConfig(PretrainedConfig):
|
||||
r"""
|
||||
[`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
|
||||
[`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified
|
||||
arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
|
||||
the defaults will yield a similar configuration to that of the Instructblipvideo
|
||||
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
|
||||
qformer_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||
The number of query tokens passed through the Transformer.
|
||||
|
||||
video_token_index (`int`, *optional*):
|
||||
Token index of special video token.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... InstructBlipVideoVisionConfig,
|
||||
... InstructBlipVideoQFormerConfig,
|
||||
... OPTConfig,
|
||||
... InstructBlipVideoConfig,
|
||||
... InstructBlipVideoForConditionalGeneration,
|
||||
... )
|
||||
|
||||
>>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> configuration = InstructBlipVideoConfig()
|
||||
|
||||
>>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> model = InstructBlipVideoForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
|
||||
|
||||
>>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
|
||||
>>> vision_config = InstructBlipVideoVisionConfig()
|
||||
>>> qformer_config = InstructBlipVideoQFormerConfig()
|
||||
>>> text_config = OPTConfig()
|
||||
|
||||
>>> config = InstructBlipVideoConfig.from_text_vision_configs(vision_config, qformer_config, text_config)
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
qformer_config=None,
|
||||
text_config=None,
|
||||
num_query_tokens=32,
|
||||
video_token_index=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
|
||||
|
||||
if qformer_config is None:
|
||||
qformer_config = {}
|
||||
logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
|
||||
|
||||
self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
|
||||
self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
|
||||
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
|
||||
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
||||
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
self.video_token_index = video_token_index
|
||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
self.initializer_factor = 1.0
|
||||
self.initializer_range = 0.02
|
||||
|
||||
@classmethod
|
||||
def from_vision_qformer_text_configs(
|
||||
cls,
|
||||
vision_config: InstructBlipVideoVisionConfig,
|
||||
qformer_config: InstructBlipVideoQFormerConfig,
|
||||
text_config: PretrainedConfig,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
|
||||
language model configurations.
|
||||
|
||||
Returns:
|
||||
[`InstructBlipVideoConfig`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(
|
||||
vision_config=vision_config.to_dict(),
|
||||
qformer_config=qformer_config.to_dict(),
|
||||
text_config=text_config.to_dict(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -69,67 +171,7 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoAttention(InstructBlipAttention):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoMLP(InstructBlipMLP):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoEncoder(InstructBlipEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin):
|
||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
@ -146,15 +188,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
||||
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
||||
config.vocab_size]`
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
|
||||
>>> import torch
|
||||
@ -339,7 +372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
r"""
|
||||
Overrides `generate` function to be able to use the model as a conditional generator.
|
||||
|
||||
Args:
|
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
@ -20,17 +20,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaNextVideoConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
|
||||
|
@ -1,8 +1,8 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# This file was automatically generated from <path_to_modular_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_xxx.py file directly. One of our CI enforces this
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
@ -19,7 +19,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@ -130,6 +129,12 @@ def unpad_image(tensor, original_size):
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
if not isinstance(original_size, (list, tuple)):
|
||||
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
|
||||
)
|
||||
original_size = original_size.tolist()
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
@ -180,6 +185,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
|
||||
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
@ -191,6 +197,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@ -455,7 +462,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration._merge_input_ids_with_image_features
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
image_features,
|
||||
@ -695,7 +701,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
|
||||
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
||||
|
||||
def pack_image_features(self, image_features, image_sizes, image_newline=None):
|
||||
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
|
||||
@ -704,6 +710,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
List of image feature tensor, each contains all the visual feature of all patches.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_select_strategy (`str`)
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||
New line embedding vector.
|
||||
Returns:
|
||||
@ -718,8 +726,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
expected_num_patches = height * width
|
||||
elif vision_feature_select_strategy == "full":
|
||||
expected_num_patches = height * width + 1
|
||||
if expected_num_patches != base_image_feature.shape[0]:
|
||||
raise ValueError("The number of patches is not consistent with the image size.")
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
|
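The hunk above makes the patch-count sanity check depend on `vision_feature_select_strategy`: with "full" the CLS position is kept, so one extra feature is expected. A minimal arithmetic sketch (image/patch sizes are illustrative):

def expected_num_patches(image_size: int, patch_size: int, vision_feature_select_strategy: str) -> int:
    # Mirrors the check above: "full" keeps the CLS token, hence one extra position.
    height = width = image_size // patch_size
    if vision_feature_select_strategy == "default":
        return height * width
    if vision_feature_select_strategy == "full":
        return height * width + 1
    raise ValueError(f"Unexpected strategy: {vision_feature_select_strategy}")

print(expected_num_patches(336, 14, "default"))  # 576
print(expected_num_patches(336, 14, "full"))     # 577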
@ -21,15 +21,13 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
LlavaNextCausalLMOutputWithPast,
|
||||
LlavaNextForConditionalGeneration,
|
||||
LlavaNextMultiModalProjector,
|
||||
image_size_to_num_patches,
|
||||
)
|
||||
|
||||
from ...generation import GenerationMixin
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -56,18 +54,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
video_token_index (`int`, *optional*, defaults to 32000):
|
||||
The video token index to encode the image prompt.
|
||||
image_token_index (`int`, *optional*, defaults to 32001):
|
||||
The image token index to encode the image prompt.
|
||||
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
|
||||
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
||||
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
||||
Stride used in the pooling layer for videos.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 288):
|
||||
Sequence length of one video embedding.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
@ -81,6 +69,16 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
of the form `(height, width)`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
video_token_index (`int`, *optional*, defaults to 32000):
|
||||
The video token index to encode the image prompt.
|
||||
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
|
||||
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
||||
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
||||
Stride used in the pooling layer for videos.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 288):
|
||||
Sequence length of one video embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -178,7 +176,13 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
|
||||
@dataclass
|
||||
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
|
||||
pass
|
||||
"""
|
||||
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class LlavaNextVideoPooler(nn.Module):
|
||||
@ -215,11 +219,7 @@ class LlavaNextVideoPooler(nn.Module):
|
||||
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
|
||||
pass
|
||||
|
||||
|
||||
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin):
|
||||
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
|
||||
super().__init__(config, **super_kwargs)
|
||||
self.vision_resampler = LlavaNextVideoPooler(config)
|
||||
@ -287,6 +287,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -298,6 +300,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
num_logits_to_keep (`int`, *optional*):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -329,7 +335,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
... frames.append(frame)
|
||||
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
||||
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto)
|
||||
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto")
|
||||
>>> processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||
|
||||
>>> prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
|
||||
@ -499,6 +505,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -529,6 +536,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
@ -541,6 +550,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
if input_ids is not None:
|
||||
@ -560,6 +570,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -182,6 +182,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
|
||||
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
|
utils/check_modular_conversion.py (new file, 76 lines)
@ -0,0 +1,76 @@
|
||||
import argparse
import difflib
import glob
import logging
from io import StringIO

# Console for rich printing
from modular_model_converter import convert_modular_file
from rich.console import Console
from rich.syntax import Syntax


logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
console = Console()


def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
    file_path = modular_file_path.replace("modular_", f"{file_type}_")
    # Read the actual modeling file
    with open(file_path, "r") as modeling_file:
        content = modeling_file.read()
    output_buffer = StringIO(generated_modeling_content[file_type][0])
    output_buffer.seek(0)
    output_content = output_buffer.read()
    diff = difflib.unified_diff(
        output_content.splitlines(),
        content.splitlines(),
        fromfile=f"{file_path}_generated",
        tofile=f"{file_path}",
        lineterm="",
    )
    diff_list = list(diff)
    # Check for differences
    if diff_list:
        if fix_and_overwrite:
            with open(file_path, "w") as modeling_file:
                modeling_file.write(generated_modeling_content[file_type][0])
            console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
        else:
            console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
            diff_text = "\n".join(diff_list)
            syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
            console.print(syntax)
            return 1
    else:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
    return 0


def compare_files(modular_file_path, fix_and_overwrite=False):
    # Generate the expected modeling content
    generated_modeling_content = convert_modular_file(modular_file_path)
    diff = 0
    for file_type in generated_modeling_content.keys():
        diff += process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
    return diff


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    parser.add_argument(
        "--files", default=["all"], type=list, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    args = parser.parse_args()
    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    non_matching_files = 0
    for modular_file_path in args.files:
        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)

    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some diff and their modeling code did not match.")
|
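A hedged usage sketch of the checker above; the model path is hypothetical and the programmatic call mirrors what the `__main__` block does for each discovered `modular_*.py` file.

from check_modular_conversion import compare_files

# Returns the number of generated files that differ from what is on disk;
# pass fix_and_overwrite=True to regenerate them in place instead of failing.
mismatches = compare_files("src/transformers/models/gemma/modular_gemma.py", fix_and_overwrite=False)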
utils/create_dependency_mapping.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import ast
from collections import defaultdict, deque


# Function to perform topological sorting
def topological_sort(dependencies):
    # Create a graph and in-degree count for each node
    graph = defaultdict(list)
    in_degree = defaultdict(int)

    # Build the graph
    for node, deps in dependencies.items():
        for dep in deps:
            graph[dep].append(node)  # node depends on dep
            in_degree[node] += 1  # increase in-degree of node

    # Add all nodes with zero in-degree to the queue
    zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])

    sorted_list = []
    # Perform topological sorting
    while zero_in_degree_queue:
        current = zero_in_degree_queue.popleft()
        sorted_list.append(current)

        # For each node that current points to, reduce its in-degree
        for neighbor in graph[current]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                zero_in_degree_queue.append(neighbor)

    # Handle nodes that have no dependencies and were not initially part of the loop
    for node in dependencies:
        if node not in sorted_list:
            sorted_list.append(node)

    return sorted_list


# Function to extract class and import info from a file
def extract_classes_and_imports(file_path):
    with open(file_path, "r") as file:
        tree = ast.parse(file.read(), filename=file_path)
    imports = set()

    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            module = node.module if isinstance(node, ast.ImportFrom) else None
            if module and "transformers" in module:
                imports.add(module)
    return imports


# Function to map dependencies between classes
def map_dependencies(py_files):
    dependencies = defaultdict(set)
    # First pass: Extract all classes and map to files
    for file_path in py_files:
        dependencies[file_path].add(None)
        class_to_file = extract_classes_and_imports(file_path)
        for module in class_to_file:
            dependencies[file_path].add(module)
    return dependencies


def find_priority_list(py_files):
    dependencies = map_dependencies(py_files)
    ordered_classes = topological_sort(dependencies)
    return ordered_classes[::-1]
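A small, self-contained illustration of `topological_sort` above (the node names are made up): a node that depends on another is emitted after it, and `find_priority_list` then reverses the sorted result before the conversion loop consumes it.

deps = {
    "modular_parent.py": [],                    # no intra-transformers dependencies
    "modular_child.py": ["modular_parent.py"],  # depends on the parent entry
}
assert topological_sort(deps) == ["modular_parent.py", "modular_child.py"]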
@@ -20,21 +20,23 @@ from typing import Dict

import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider

from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
@ -82,12 +84,16 @@ class ClassFinder(CSTVisitor):
|
||||
self.function_def = {} # stores global scope function definition
|
||||
self.assignments = {} # LLAMA_DOCSTRING
|
||||
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
|
||||
self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
|
||||
# fmt: on
|
||||
|
||||
def _update_class_dependency(self, name, value):
|
||||
"""Update the dependency mapping for `name` with `value` by appending the previous
|
||||
dependencies to the new `value`.
|
||||
"""
|
||||
dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value})
|
||||
self.first_lvl_dependency_mapping[name] = dep
|
||||
|
||||
dep = set(self.class_dependency_mapping.get(value, set()))
|
||||
dep |= set(self.class_dependency_mapping.get(name, {})) | set({value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
@ -146,7 +152,16 @@ class ClassFinder(CSTVisitor):
|
||||
def leave_Decorator(self, node):
|
||||
if hasattr(node.decorator, "args"):
|
||||
for k in node.decorator.args:
|
||||
if k.value.value in self.assignments:
|
||||
if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value:
|
||||
if k.value.func.value.value not in self.assignments:
|
||||
raise ValueError(
|
||||
f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}"
|
||||
)
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
|
||||
self._update_class_dependency(name, k.value.func.value.value)
|
||||
elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments:
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
|
||||
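The decorator handling above leans on libcst matchers; the following standalone, hedged sketch shows how an argument such as `LLAMA_INPUTS_DOCSTRING.format(...)` is recognized as a call on a plain name, which is what lets the visitor tie a decorator back to the module-level assignment it depends on.

import libcst as cst
import libcst.matchers as m

arg = cst.parse_expression("LLAMA_INPUTS_DOCSTRING.format(num_tokens=10)")
# a Call whose func is an Attribute on a bare Name, exactly the shape the matcher above checks
assert m.matches(arg, m.Call(func=m.Attribute(value=m.Name())))
assert arg.func.value.value == "LLAMA_INPUTS_DOCSTRING"  # the assignment the decorator depends on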
@ -178,6 +193,10 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
self.default_name = "".join(x.title() for x in new_name.split("_"))
|
||||
if self.new_name in CONFIG_MAPPING_NAMES:
|
||||
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
|
||||
"Config", ""
|
||||
) # the best source of truth for class names. Could also just use the ones de
|
||||
self.patterns = {
|
||||
old_name: new_name,
|
||||
old_name.upper(): new_name.upper(),
|
||||
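For reference, the fallback class-name prefix computed above is just the CamelCased model name; the entry in `CONFIG_MAPPING_NAMES` (minus the trailing "Config") is preferred when the model is registered there. A small hedged example with a made-up name:

new_name = "my_new_model"  # hypothetical snake_case model name
default_name = "".join(x.title() for x in new_name.split("_"))
assert default_name == "MyNewModel"  # the CamelCase prefix substituted for the old model's class-name prefix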
@ -193,7 +212,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
|
||||
def replace(match):
|
||||
word = match.group(0)
|
||||
return self.patterns.get(word, self.default_name)
|
||||
result = self.patterns.get(word, self.default_name)
|
||||
return result
|
||||
|
||||
return compiled_regex.sub(replace, text)
|
||||
|
||||
@ -227,35 +247,102 @@ DOCSTRING_NODE = m.SimpleStatementLine(
|
||||
)
|
||||
|
||||
|
||||
def SUPER_CALL_NODE(func_name):
|
||||
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
|
||||
|
||||
|
||||
def get_docstring_indent(docstring):
|
||||
# Match the first line after the opening triple quotes
|
||||
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
|
||||
if match:
|
||||
# Return the indentation spaces captured
|
||||
return len(match.group(1))
|
||||
return 0
|
||||
|
||||
|
||||
def merge_docstrings(original_docstring, updated_docstring):
|
||||
# indent_level = get_docstring_indent(updated_docstring)
|
||||
original_level = get_docstring_indent(original_docstring)
|
||||
if " Args:\n " not in updated_docstring:
|
||||
# Split the docstring at the example section, assuming `"""` is used to define the docstring
|
||||
parts = original_docstring.split("```")
|
||||
if "```" in updated_docstring and len(parts) > 1:
|
||||
updated_docstring = updated_docstring.lstrip('r"')
|
||||
new_parts = updated_docstring.split("```")
|
||||
if len(new_parts) != 3:
|
||||
raise ValueError("There should only be one example, and it should have opening and closing '```'")
|
||||
parts[1] = new_parts[1]
|
||||
updated_docstring = "".join(
|
||||
[
|
||||
parts[0].rstrip(" \n") + new_parts[0],
|
||||
f"\n{original_level*' '}```",
|
||||
parts[1],
|
||||
"```",
|
||||
parts[2],
|
||||
]
|
||||
)
|
||||
elif updated_docstring not in original_docstring:
|
||||
# add tabulation if we are at the lowest level.
|
||||
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
|
||||
updated_docstring = updated_docstring.replace("\n ", "\n ")
|
||||
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
|
||||
return updated_docstring
|
||||
|
||||
|
||||
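A hedged, standalone check of the `SUPER_CALL_NODE` matcher defined above: it matches both `return super().func(...)` and bare `super().func(...)` statements, which is what `replace_super_calls` looks for before splicing in the parent's body.

import libcst as cst
import libcst.matchers as m

def SUPER_CALL_NODE(func_name):  # copied from above so the snippet is self-contained
    return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))

assert m.matches(cst.parse_expression("super().forward(input_ids)"), SUPER_CALL_NODE("forward"))
assert not m.matches(cst.parse_expression("self.forward(input_ids)"), SUPER_CALL_NODE("forward"))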
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods):
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
|
||||
self.python_module = python_module
|
||||
self.original_methods = original_methods
|
||||
self.updated_methods = updated_methods
|
||||
self.all_assign_target = {}
|
||||
self.deleted_targets = {} # child node can delete some arguments
|
||||
self.class_name = class_name
|
||||
|
||||
def update_body(self, existing_body, new_statements):
|
||||
"""
|
||||
Helper method to update the body by removing duplicates before adding new statements.
|
||||
`existing_body` is the body of the original method, the parent class
|
||||
`new_statements` are the additional statements
|
||||
"""
|
||||
deduplicated_new_body = []
|
||||
existing_nodes = set()
|
||||
for node in new_statements:
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = self.python_module.code_for_node(node.body[0].targets[0].target)
|
||||
self.all_assign_target[target] = node
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
|
||||
target = self.python_module.code_for_node(node.body[0].target)
|
||||
self.deleted_targets[target] = node
|
||||
continue
|
||||
|
||||
for stmt in existing_body:
|
||||
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
|
||||
if target in self.deleted_targets:
|
||||
logger.warning(f"Deleted the assign for {target}")
|
||||
continue
|
||||
if target in self.all_assign_target:
|
||||
stmt = self.all_assign_target[target]
|
||||
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
|
||||
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
|
||||
deduplicated_new_body.append(stmt)
|
||||
existing_nodes.add(comment_less_code)
|
||||
|
||||
for node in new_statements:
|
||||
code = self.python_module.code_for_node(node)
|
||||
comment_less_code = re.sub(r"#.*", "", code).strip()
|
||||
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
|
||||
if (
|
||||
node not in deduplicated_new_body
|
||||
and "super().__init__" not in comment_less_code
|
||||
and comment_less_code not in existing_nodes
|
||||
):
|
||||
if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
|
||||
# HACK here to fix the pos_init() that has to be last we kinda do this.
|
||||
deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:]
|
||||
existing_nodes.add(comment_less_code)
|
||||
for stmt in existing_body:
|
||||
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
|
||||
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
|
||||
if comment_less_code not in existing_nodes:
|
||||
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
|
||||
continue
|
||||
deduplicated_new_body.append(stmt)
|
||||
existing_nodes.add(stmt)
|
||||
else:
|
||||
logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}")
|
||||
return deduplicated_new_body
|
||||
|
||||
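The de-duplication in `update_body` above compares statements after a small normalization; a hedged sketch of that normalization on a made-up assignment:

import re

code = "self.vocab_size = config.vocab_size  # number of tokens"
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
assert comment_less_code == "self.vocab_size = config.vocab_size"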
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
|
||||
@ -263,26 +350,37 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
to super().func_name() with the source code of the parent class' `func_name`.
|
||||
It keeps everything that is defined before `super().func_name()`.
|
||||
"""
|
||||
new_body = []
|
||||
self.has_docstring = False
|
||||
for expr in node.body:
|
||||
self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE)
|
||||
parent_has_docstring = False
|
||||
if func_name in self.original_methods:
|
||||
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
|
||||
new_body = []
|
||||
has_super_call = False
|
||||
for idx, expr in enumerate(node.body):
|
||||
if m.matches(
|
||||
expr,
|
||||
m.SimpleStatementLine(
|
||||
body=[
|
||||
m.Return(
|
||||
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
|
||||
)
|
||||
| m.Expr(
|
||||
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
|
||||
)
|
||||
]
|
||||
body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
|
||||
),
|
||||
):
|
||||
if idx != 0 and func_name == "__init__":
|
||||
raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
|
||||
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
|
||||
has_super_call = True
|
||||
elif m.matches(expr, DOCSTRING_NODE):
|
||||
self.has_docstring = True
|
||||
if parent_has_docstring: # actually here we ought to de-duplicate?
|
||||
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
|
||||
updated_docstring = expr.body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
|
||||
else:
|
||||
new_node = [expr]
|
||||
new_body.extend(new_node)
|
||||
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
|
||||
new_body.append(expr)
|
||||
if not self.has_docstring and parent_has_docstring:
|
||||
new_body = [self.original_methods[func_name].body.body[0]] + new_body
|
||||
return node.with_changes(body=new_body)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
@ -330,14 +428,22 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||
| ```
|
||||
"""
|
||||
original_node = class_finder.classes[class_name]
|
||||
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
|
||||
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
|
||||
original_methods = {
|
||||
f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
|
||||
for f in original_node.body.body
|
||||
}
|
||||
updated_methods = {
|
||||
f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
|
||||
for f in updated_node.body.body
|
||||
}
|
||||
end_meth = []
|
||||
|
||||
assign_targets = {}
|
||||
docstring_node = []
|
||||
# Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
|
||||
for func in original_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else func
|
||||
if name in updated_methods and updated_methods[name] is not None:
|
||||
name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
|
||||
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
|
||||
new_params = updated_methods[name].params
|
||||
# Replace the method in the replacement class, preserving decorators
|
||||
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
|
||||
@ -348,22 +454,61 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
|
||||
)
|
||||
func = func.with_changes(body=updated_methods[name].body, params=new_params)
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = class_finder.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, DOCSTRING_NODE):
|
||||
docstring_node = [func]
|
||||
else:
|
||||
end_meth.append(func)
|
||||
|
||||
# Port new methods that are defined only in diff-file and append at the end
|
||||
for name, func in updated_methods.items():
|
||||
# Port new methods that are defined only in modular-file and append at the end
|
||||
for func in updated_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
|
||||
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
|
||||
# Extract the original docstring
|
||||
updated_docstring = func.body[0].value.value
|
||||
original_docstring = docstring_node[0].body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
# Update the docstring in the original function
|
||||
docstring_node = [
|
||||
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
|
||||
]
|
||||
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
|
||||
end_meth.append(func)
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
# TODO we only use single assign might cause issues
|
||||
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = class_finder.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
end_meth = docstring_node + list(assign_targets.values()) + end_meth
|
||||
|
||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
|
||||
new_replacement_class = new_module.visit(
|
||||
SuperTransformer(temp_module, original_methods, updated_methods, class_name)
|
||||
)
|
||||
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||
|
||||
return original_node.with_changes(body=new_replacement_body)
|
||||
|
||||
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
TYPE_TO_FILE_TYPE = {
|
||||
"Config": "configuration",
|
||||
"Tokenizer": "tokenization",
|
||||
"Processor": "processor",
|
||||
"ImageProcessor": "image_processing",
|
||||
"FeatureExtractor": "feature_extractor",
|
||||
}
|
||||
|
||||
|
||||
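A hedged illustration of how `TYPE_TO_FILE_TYPE` above routes a class to an output file by its name suffix (the class and file names are hypothetical); classes with no matching suffix fall back to the `modeling` file, as done later in `leave_ClassDef`.

import re

TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processor",
    "ImageProcessor": "image_processing",
    "FeatureExtractor": "feature_extractor",
}
match = re.search(rf"({'|'.join(TYPE_TO_FILE_TYPE.keys())})$", "MyNewModelConfig")
file_type = TYPE_TO_FILE_TYPE[match.group(1)] if match else "modeling"
assert file_type == "configuration"  # i.e. the class would land in a configuration_*.py output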
class ModularConverterTransformer(CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||
|
||||
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
|
||||
@ -378,11 +523,21 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
|
||||
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
|
||||
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
|
||||
self.new_body = {} # store the new body, all global scope nodes should be added here
|
||||
self.inserted_deps = [] # nodes inserted via super dependency
|
||||
self.all_imports = [] # just stores all of the imports
|
||||
self.all_safe_imports = [] # stores the import under simple statements
|
||||
self.global_scope_index = 0
|
||||
# fmt: on
|
||||
self.files = { # mapping for different component bodies
|
||||
"modeling": {},
|
||||
"configuration": {},
|
||||
"tokenization": {},
|
||||
"processing": {},
|
||||
"image_processing": {},
|
||||
"feature_extractor": {},
|
||||
}
|
||||
self.match_patterns = "|".join(self.files.keys())
|
||||
self.all_functions = {}
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
@ -393,7 +548,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
import_statement = self.python_module.code_for_node(node.module)
|
||||
if m.matches(node.module, m.Attribute()):
|
||||
for imported_ in node.names:
|
||||
_import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
|
||||
_import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement)
|
||||
if _import:
|
||||
source = _import.groups()[0]
|
||||
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
|
||||
@ -401,44 +556,38 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
|
||||
)
|
||||
if import_statement not in self.transformers_imports:
|
||||
if "models" not in import_statement:
|
||||
import_statement = "models." + import_statement
|
||||
if "transformers" not in import_statement:
|
||||
import_statement = "transformers." + import_statement
|
||||
source_code = get_module_source_from_name(import_statement)
|
||||
tree = cst.parse_module(source_code)
|
||||
self.transformers_imports[import_statement] = tree
|
||||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
self.imported_mapping[imported_class] = import_statement
|
||||
|
||||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.global_scope_index += 100
|
||||
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
|
||||
return node
|
||||
if m.matches(node.module, m.Name()):
|
||||
if "transformers" == import_statement:
|
||||
raise ValueError(
|
||||
f"You are importing from {import_statement} directly using global imports. Import from the correct local path"
|
||||
)
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node, updated_node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
if parent_node not in self.all_imports:
|
||||
if updated_node not in self.all_imports:
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
|
||||
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
|
||||
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
|
||||
if re.search(
|
||||
rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement
|
||||
): # OR MATCH ..llama.modeling_llama
|
||||
return cst.RemoveFromParent()
|
||||
if parent_node not in self.all_imports:
|
||||
if updated_node not in self.all_imports:
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
self.global_scope_index += 100
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
# TODO This only works for single target assigns!
|
||||
node_name = updated_node.body[0].targets[0].target.value
|
||||
else:
|
||||
node_name = self.python_module.code_for_node(updated_node.body[0])
|
||||
self.new_body[node_name] = {
|
||||
"insert_idx": self.global_scope_index,
|
||||
"node": updated_node,
|
||||
}
|
||||
self.config_body = [updated_node]
|
||||
return updated_node
|
||||
|
||||
def leave_ClassDef(self, original_node, updated_node):
|
||||
@ -454,6 +603,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
"""
|
||||
class_name = original_node.name.value
|
||||
bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
|
||||
all_bases = [k.value.value for k in original_node.bases]
|
||||
self.global_scope_index += 100
|
||||
for super_class in bases:
|
||||
if super_class not in self.imported_mapping:
|
||||
@ -469,7 +619,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
raise ValueError(
|
||||
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
|
||||
)
|
||||
|
||||
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
|
||||
visited_module = self.visited_module
|
||||
if super_file_name not in visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(
|
||||
@ -490,22 +640,47 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
|
||||
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
|
||||
start_insert_idx = self.global_scope_index
|
||||
file_to_update = self.files[file_type]
|
||||
is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n"
|
||||
for dependency, _ in list_dependencies:
|
||||
# we can write to the correct body, using the source of the parent class
|
||||
node = class_finder.global_nodes.get(dependency, None)
|
||||
if node is not None and "Config" not in class_name:
|
||||
if dependency not in self.new_body:
|
||||
if node is not None:
|
||||
if dependency not in file_to_update:
|
||||
start_insert_idx -= 1
|
||||
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
||||
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
||||
elif dependency not in self.inserted_deps:
|
||||
# make sure the node is written after its dependencies
|
||||
start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
|
||||
start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
|
||||
if (
|
||||
dependency in file_to_update.keys()
|
||||
and dependency in class_finder.first_lvl_dependency_mapping[class_name]
|
||||
):
|
||||
# If dependency is defined, but not used, raise error
|
||||
calls = m.findall(original_node, m.Call(func=m.Name(dependency)))
|
||||
if not calls and not is_empty_node and dependency not in all_bases:
|
||||
raise ValueError(
|
||||
f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used
|
||||
when you define `{class_name}`, as it is one of it's direct dependencies. Make sure
|
||||
you use it in the `__init__` function."""
|
||||
)
|
||||
self.inserted_deps.append(dependency)
|
||||
|
||||
if len(list_dependencies) > 0:
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
|
||||
if "Config" in class_name:
|
||||
self.config_body += [updated_node]
|
||||
else:
|
||||
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
raise ValueError(
|
||||
f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
|
||||
)
|
||||
|
||||
# Now, if a class was defined without parents, we look for the name
|
||||
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
|
||||
match = re.search(rf"({match_pattern})$", class_name)
|
||||
if match:
|
||||
key = TYPE_TO_FILE_TYPE[match.group(1)]
|
||||
self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
else:
|
||||
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
return updated_node
|
||||
|
||||
def leave_If(self, original_node, node):
|
||||
@ -513,66 +688,69 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
if m.matches(parent_node, m.Module()):
|
||||
full_statement = self.python_module.code_for_node(original_node.test)
|
||||
if re.search(r"[\s\S]*is_.*available", full_statement):
|
||||
self.all_imports.append(node)
|
||||
self.all_safe_imports.append(node)
|
||||
elif full_statement not in self.new_body:
|
||||
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
|
||||
return node
|
||||
|
||||
def leave_Module(self, original_node: cst.Assign, node):
|
||||
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
||||
dependency_imports = {}
|
||||
config_imports = []
|
||||
for visiter in self.visited_module.values():
|
||||
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
|
||||
dependency_imports = {file_type: imports.copy() for file_type in self.files}
|
||||
for super_file_name, visiter in self.visited_module.items():
|
||||
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
|
||||
dependency_imports[file_type].update(
|
||||
{self.python_module.code_for_node(k): k for k in visiter.imports.values()}
|
||||
)
|
||||
|
||||
# manually clean up if it's importing a config from configuration file (ruff doesn't do that)
|
||||
config_imports = []
|
||||
for i in list(dependency_imports.values()):
|
||||
if (
|
||||
hasattr(i.body[0], "module")
|
||||
and isinstance(i.body[0].module, cst.Name)
|
||||
and f"configuration_{self.model_name}" in i.body[0].module.value
|
||||
):
|
||||
pass
|
||||
else:
|
||||
config_imports.append(i)
|
||||
|
||||
if hasattr(self, "config_body"):
|
||||
self.config_body = list(imports.values()) + config_imports + self.config_body
|
||||
dependency_imports.update(imports)
|
||||
new_body = list(dependency_imports.values())
|
||||
if len(self.new_body.keys()) > 0:
|
||||
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
|
||||
else:
|
||||
new_body = []
|
||||
return node.with_changes(body=[*new_body])
|
||||
for file, body in self.files.items():
|
||||
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
|
||||
if len(new_body) > 0:
|
||||
if file in dependency_imports.keys():
|
||||
new_body = list(dependency_imports[file].values()) + new_body
|
||||
self.files[file] = cst.Module(body=[*new_body], header=node.header)
|
||||
return node
|
||||
|
||||
|
||||
def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
|
||||
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
|
||||
def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
|
||||
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
|
||||
output = {}
|
||||
if pattern is not None:
|
||||
model_name = pattern.groups()[0]
|
||||
# Parse the Python file
|
||||
with open(diff_file, "r") as file:
|
||||
with open(modular_file, "r") as file:
|
||||
code = file.read()
|
||||
module = cst.parse_module(code)
|
||||
wrapper = MetadataWrapper(module)
|
||||
if cst_transformers is None:
|
||||
cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
|
||||
new_mod = wrapper.visit(cst_transformers)
|
||||
ruffed_code = run_ruff(new_mod.code, True)
|
||||
cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name)
|
||||
wrapper.visit(cst_transformers)
|
||||
for file, node in cst_transformers.files.items():
|
||||
if node != {}:
|
||||
ruffed_code = run_ruff(AUTO_GENERATED_MESSAGE + node.code, True)
|
||||
formatted_code = run_ruff(ruffed_code, False)
|
||||
if len(formatted_code.strip()) > 0:
|
||||
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
|
||||
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
|
||||
output[file] = [formatted_code, ruffed_code]
|
||||
return output
|
||||
else:
|
||||
print(f"modular pattern not found in {modular_file}, exiting")
|
||||
return {}
|
||||
|
||||
if hasattr(cst_transformers, "config_body"):
|
||||
config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
|
||||
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
|
||||
ruffed_code = run_ruff(config_module.code, True)
|
||||
formatted_code = run_ruff(ruffed_code, False)
|
||||
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
|
||||
|
||||
# TODO optimize by re-using the class_finder
|
||||
return cst_transformers
|
||||
def save_modeling_file(modular_file, converted_file):
|
||||
for file_type in converted_file.keys():
|
||||
non_comment_lines = len(
|
||||
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
|
||||
)
|
||||
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
|
||||
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
|
||||
f.write(converted_file[file_type][0])
|
||||
else:
|
||||
non_comment_lines = len(
|
||||
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
|
||||
)
|
||||
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
|
||||
logger.warning("The modeling code contains errors, it's written without formatting")
|
||||
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
|
||||
f.write(converted_file[file_type][1])
|
||||
|
||||
|
||||
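Putting the two entry points above together, a hedged end-to-end sketch (the path is hypothetical): convert one modular file and write the generated per-file-type outputs next to it.

from modular_model_converter import convert_modular_file, save_modeling_file

modular_file = "src/transformers/models/my_new_model/modular_my_new_model.py"  # hypothetical path
converted = convert_modular_file(modular_file)  # {"modeling": [formatted, unformatted], "configuration": ..., ...}
save_modeling_file(modular_file, converted)     # writes the modeling_/configuration_/... files when non-empty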
if __name__ == "__main__":
@@ -581,22 +759,24 @@ if __name__ == "__main__":
        "--files_to_parse",
        default=["all"],
        nargs="+",
        help="A list of `diff_xxxx` files that should be converted to single model file",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    parser.add_argument(
        "--old_model_name",
        required=False,
        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
    )
    parser.add_argument(
        "--new_model_name",
        required=False,
        help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
        help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
    )
    args = parser.parse_args()
    if args.files_to_parse == ["all"]:
        args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
    for file_name in args.files_to_parse:
        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    for file_name in find_priority_list(args.files_to_parse):
        print(f"Converting {file_name} to a single model single file format")
        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
        converter = convert_file(file_name, args.old_model_name, args.new_model_name)
        converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
        converter = save_modeling_file(file_name, converted_files)