Mirror of https://github.com/huggingface/transformers.git
Modular transformers: modularity and inheritance for new model additions (#33248)
* update exampel * update * push the converted diff files for testing and ci * correct one example * fix class attributes and docstring * nits * oups * fixed config! * update * nitd * class attributes are not matched against the other, this is missing * fixed overwriting self.xxx now onto the attributes I think * partial fix, now order with docstring * fix docstring order? * more fixes * update * fix missing docstrings! * examples don't all work yet * fixup * nit * updated * hick * update * delete * update * update * update * fix * all default * no local import * fix more diff * some fix related to "safe imports" * push fixed * add helper! * style * add a check * all by default * add the * update * FINALLY! * nit * fix config dependencies * man that is it * fix fix * update diffs * fix the last issue * re-default to all * alll the fixes * nice * fix properties vs setter * fixup * updates * update dependencies * make sure to install what needs to be installed * fixup * quick fix for now * fix! * fixup * update * update * updates * whitespaces * nit * fix * simplify everything, and make it file agnostic (should work for image processors) * style * finish fixing all import issues * fixup * empty modeling should not be written! * Add logic to find who depends on what * update * cleanup * update * update gemma to support positions * some small nits * this is the correct docstring for gemma2 * fix merging of docstrings * update * fixup * update * take doc into account * styling * update * fix hidden activation * more fixes * final fixes! * fixup * fixup instruct blip video * update * fix bugs * align gemma2 with the rest as well * updats * revert * update * more reversiom * grind * more * arf * update * order will matter * finish del stuff * update * rename to modular * fixup * nits * update makefile * fixup * update order of the checks! * fix * fix docstring that has a call inside * fiix conversion check * style * add some initial documentation * update * update doc * some fixup * updates * yups * Mostly todo gimme a minut * update * fixup * revert some stuff * Review docs for the modular transformers (#33472) Docs * good update * fixup * mmm current updates lead to this code * okay, this fixes it * cool * fixes * update * nit * updates * nits * fix doc * update * revert bad changes * update * updates * proper update * update * update? * up * update * cool * nits * nits * bon bon * fix * ? * minimise changes * update * update * update * updates? * fixed gemma2 * kind of a hack * nits * update * remove `diffs` in favor of `modular` * fix make fix copies --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
Commit 317e069ee7 (parent 75b7485cc7)
@@ -137,7 +137,7 @@ jobs:
     parallelism: 1
     steps:
       - checkout
-      - run: uv pip install -e .
+      - run: uv pip install -e ".[quality]"
       - run:
           name: Show installed libraries and their versions
           command: pip freeze | tee installed.txt
@@ -162,13 +162,14 @@ jobs:
     parallelism: 1
     steps:
       - checkout
-      - run: uv pip install -e .
+      - run: uv pip install -e ".[quality]"
       - run:
           name: Show installed libraries and their versions
           command: pip freeze | tee installed.txt
       - store_artifacts:
           path: ~/transformers/installed.txt
       - run: python utils/check_copies.py
+      - run: python utils/check_modular_conversion.py
       - run: python utils/check_table.py
       - run: python utils/check_dummies.py
       - run: python utils/check_repo.py
Makefile | 2
@@ -36,6 +36,7 @@ autogenerate_code: deps_table_update
 
 repo-consistency:
 	python utils/check_copies.py
+	python utils/check_modular_conversion.py
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
@@ -80,6 +81,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
 
 fix-copies:
 	python utils/check_copies.py --fix_and_overwrite
+	python utils/check_modular_conversion.py --fix_and_overwrite
 	python utils/check_table.py --fix_and_overwrite
 	python utils/check_dummies.py --fix_and_overwrite
 	python utils/check_doctest_list.py --fix_and_overwrite
@@ -5,6 +5,8 @@
     title: Quick tour
   - local: installation
     title: Installation
+  - local: add_new_model
+    title: Adding a new model to `transformers`
   title: Get started
 - sections:
   - local: pipeline_tutorial
@@ -149,6 +151,8 @@
     title: Interoperability with GGUF files
   - local: tiktoken
     title: Interoperability with TikToken files
+  - local: modular_transformers
+    title: Modularity in `transformers`
   title: Developer guides
 - sections:
   - local: quantization/overview
docs/source/en/modular_transformers.md | 121 (new file)
@@ -0,0 +1,121 @@

# Modular transformers

`transformers` is an opinionated framework; our philosophy is defined in the following [conceptual guide](./philosophy).

The core of that philosophy is exemplified by the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy)
aspect of the library. The downside of this approach is that it limits the inheritance and importability of components
from one file to others in the toolkit.

As a result, model components tend to be repeated across many files. There are as many attention layers defined
in `transformers` as there are models, and a significant number of those are identical to each other.
The unfortunate consequence is that independent implementations tend to diverge as fixes and changes get applied
to specific parts of the code.

To mitigate this issue, we introduced the concept of "copies" across the library. By adding a comment indicating
that code is a copy of another, we can enforce through CI and local commands that copies do not diverge. However,
while the complexity is low, this is often quite tedious to do.
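
For illustration, such a marker looks roughly like this (a hypothetical instance of the convention; the referenced path and the `with X->Y` renaming rule depend on the model being copied):

```python
from torch import nn


# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Gemma
class GemmaMLP(nn.Module):
    ...
```

`utils/check_copies.py` then verifies that the body of `GemmaMLP` stays in sync with `LlamaMLP`.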

Finally, this adds a significant overhead to contributing models, which we would like to remove.
A model contribution typically requires adding modeling code (~1k lines), a processor (~500 lines), tests, docs,
etc. Model contribution PRs rarely add less than 3-5k lines of code, with much of this code being boilerplate.

This raises the bar for contributions, and with Modular Transformers, we're aiming to lower the bar to a much more
acceptable point.

## What is it?

Modular Transformers introduces the concept of a "modular" file to a model folder. This modular file accepts code
that isn't typically accepted in modeling/processing files, as it allows importing from neighbouring models as well
as inheriting classes from other models.

This modular file defines models, processors, and the configuration class that would otherwise be defined in their
respective modules.

Finally, this feature introduces a new `linter` which will "unravel" the modular file into the "single model, single
file" directory structure. These files are auto-generated every time the script is run, reducing the required
contribution to the modular file, and therefore only to the changes between the contributed model and others.

Model users will end up importing and using the single-file interface, so no change is expected here. In doing this, we
hope to combine the best of both worlds: enabling simple contributions while sticking to our philosophy.

This is therefore a replacement for the `# Copied from` markers, and previously contributed models can be expected to
be moved to the new Modular Transformers format in the coming months.

### Details

The "linter", which unravels the inheritance and creates all single files from the modular file, will flatten the
inheritance while trying to be invisible to Python users. At this time, the linter flattens a **single** level of
inheritance.

For example:
- If a configuration class inherits from another and adds/deletes an argument, the generated file will either directly
  reference it (in case of addition) or completely remove it (in case of deletion); see the sketch below.
- If a class inherits from another, for example `class GemmaModel(LlamaModel):`, dependencies are automatically
  inferred. All submodules will be automatically inferred from the superclass.

You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
file, and the corresponding files will be created for you.
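
As a rough sketch of the first bullet above (the actual modular source file is not part of this diff, so the names and defaults below are only an approximation based on the generated `examples/modular-transformers/configuration_my_new_model.py`), a modular configuration can inherit from an existing one, override a default, and add an argument, and the linter will expand it into a full standalone configuration class:

```python
from ..llama.configuration_llama import LlamaConfig


class MyNewModelConfig(LlamaConfig):
    def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
        super().__init__(**super_kwargs)
        self.mlp_bias = mlp_bias  # overrides the default inherited from the parent config
        self.new_param = new_param  # argument added on top of the inherited ones
```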

### Enforcement

[TODO] We are introducing a new test that makes sure the generated content matches what is present in the `modular_xxxx.py` file.
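
Based on the CI and Makefile changes in this commit, the check (and its auto-fix counterpart) can be run locally alongside the other repo-consistency utilities:

```bash
# Verify that the generated single-file code matches the modular file
python utils/check_modular_conversion.py

# Regenerate the single-file code from the modular file
python utils/check_modular_conversion.py --fix_and_overwrite
# The same commands are wired into `make repo-consistency` and `make fix-copies`.
```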

### Examples

Here is a quick example with BERT and RoBERTa. The two models are intimately related: their modeling implementation
differs solely by a change in the embedding layer.

Instead of redefining the model entirely, here is what the `modular_roberta.py` file looks like for the modeling &
configuration classes (for the sake of the example, the tokenizer is ignored at this time, as it is very different).

```python
from torch import nn

from ..bert.configuration_bert import BertConfig
from ..bert.modeling_bert import (
    BertModel,
    BertEmbeddings,
    BertForMaskedLM,
)


# The RoBERTa config is identical to BERT's config
class RobertaConfig(BertConfig):
    model_type = "roberta"


# We redefine the embeddings here to highlight the padding ID difference, and we redefine the position embeddings
class RobertaEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )


# The RoBERTa model is identical to the BERT model, except for the embedding layer.
# We redefine the embeddings above, so there is no need to do additional work here
class RobertaModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = RobertaEmbeddings(config)


# The heads now only need to point to the correct `RobertaModel`
class RobertaForMaskedLM(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = RobertaModel(config)
```

Note that if you do not use a dependency that you defined, you will get the following error:

```bash
ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used
            when you define `BertModel`, as it is one of its direct dependencies. Make sure
            you use it in the `__init__` function.
```

Additionally, you may find a list of examples in the `examples/modular-transformers` directory.
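
To regenerate the single-file outputs for one of those examples yourself, you can call the converter script added in this PR directly (a sketch; per its README, `libcst` is the converter's only extra requirement):

```bash
pip install libcst
python utils/modular_model_converter.py --files_to_parse examples/modular-transformers/modular_my_new_model2.py
```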

## What it is not

It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
@@ -1,20 +0,0 @@ (file removed: the previous diff-converter README)
-# Using the `diff_converter` linter
-
-`pip install libcst` is a must!
-
-# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs
-
-The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`.
-
-Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage.
-
-`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
-
-## How it works
-We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies.
-
-The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file.
-We use ruff to automatically remove the potential duplicate imports.
-
-## Why we use libcst instead of the native AST?
-AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`
examples/modular-transformers/README.md | 20 (new file)
@@ -0,0 +1,20 @@

# Using the `modular_converter` linter

`pip install libcst` is a must!

# `sh examples/modular-transformers/convert_examples.sh` to get the converted outputs

The modular converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in Python to convert a modular file like `modular_gemma.py` into a `single model, single file` layout.

Examples of possible usage are available in `examples/modular-transformers`, or see `modular_gemma` for a full model usage.

`python utils/modular_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/modular-transformers/modular_my_new_model2.py"`

## How it works
We use the `libcst` parser to produce an AST representation of the `modular_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx`, we parse the source code of that module and build a class dependency mapping, which allows us to unpack the modular file's dependencies (a sketch of this first step is shown below).

The code from the `modular` file and the class dependency mapping are "merged" to produce the single model, single file.
We use ruff to automatically remove the potential duplicate imports.
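
As an illustration of that first step (this is not the actual implementation in `utils/modular_model_converter.py`, just a minimal sketch of the idea), collecting the names imported from neighbouring `modeling_*`/`configuration_*` modules with `libcst` could look like this:

```python
import libcst as cst


class ModelingImportCollector(cst.CSTVisitor):
    """Record names imported from `modeling_*` / `configuration_*` modules."""

    def __init__(self):
        self.imported_names = []

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        if node.module is None or isinstance(node.names, cst.ImportStar):
            return
        module_str = cst.Module([]).code_for_node(node.module)
        if "modeling_" in module_str or "configuration_" in module_str:
            self.imported_names.extend(alias.name.value for alias in node.names)


source = open("examples/modular-transformers/modular_my_new_model2.py").read()
tree = cst.parse_module(source)
collector = ModelingImportCollector()
tree.visit(collector)
print(collector.imported_names)  # classes to resolve from the parent model's source files
```

Those names are then resolved against the parent model's source to build the dependency mapping mentioned above.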

## Why do we use libcst instead of the native AST?
AST is super powerful, but it does not keep the `docstring`, `comment`, or code formatting. Thus we decided to go with `libcst`.
examples/modular-transformers/configuration_my_new_model.py | 196 (new file)
@@ -0,0 +1,196 @@

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class MyNewModelConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MyNewModel-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MyNewModelModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
            MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'my_new_model3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_heads
        new_param (`int`, *optional*, defaults to `False`):
            A fun new parameter

    ```python
    >>> from transformers import MyNewModelModel, MyNewModelConfig

    >>> # Initializing a MyNewModel my_new_model-7b style configuration
    >>> configuration = MyNewModelConfig()

    >>> # Initializing a model from the my_new_model-7b style configuration
    >>> model = MyNewModelModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "my_new_model"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=True,
        head_dim=None,
        new_param=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mlp_bias = mlp_bias
        self.new_param = new_param
examples/modular-transformers/configuration_my_new_model2.py | 97 (new file)
@@ -0,0 +1,97 @@

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class MyNewModel2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "my_new_model2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
examples/modular-transformers/configuration_new_model.py | 134 (new file)
@@ -0,0 +1,134 @@

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Example where we only want to overwrite the defaults of an init

from transformers import PretrainedConfig


class NewModelConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NewModelModel`]. It is used to instantiate an NewModel
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the NewModel-7B.
    e.g. [google/new_model-7b](https://huggingface.co/google/new_model-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the NewModel model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`NewModelModel`]
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    ```python
    >>> from transformers import NewModelModel, NewModelConfig
    >>> # Initializing a NewModel new_model-7b style configuration
    >>> configuration = NewModelConfig()
    >>> # Initializing a model from the new_model-7b style configuration
    >>> model = NewModelModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "new_model"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256030,
        hidden_size=64,
        intermediate_size=90,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=1500,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def num_heads(self):
        return self.num_attention_heads
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # Iterate over each file in the current directory
-for file in examples/diff-conversion/diff_*; do
+for file in examples/modular-transformers/modular_*; do
   # Check if it's a regular file
   if [ -f "$file" ]; then
     # Call the Python script with the file name as an argument
examples/modular-transformers/modeling_dummy.py | 1053 (new file; diff suppressed because it is too large)
examples/modular-transformers/modeling_dummy_bert.py | 1038 (new file; diff suppressed because it is too large)
examples/modular-transformers/modeling_my_new_model2.py | 1059 (new file; diff suppressed because it is too large)

examples/modular-transformers/modeling_super.py | 953 (new file)
@@ -0,0 +1,953 @@

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
    BaseModelOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
)
from .configuration_super import SuperConfig


logger = logging.get_logger(__name__)


def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask


class SuperRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        SuperRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class SuperRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[SuperConfig] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SuperMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class SuperAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = SuperRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||||
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class SuperFlashAttention2(SuperAttention):
|
||||||
|
"""
|
||||||
|
Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stays
|
||||||
|
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||||
|
flash attention and deal with padding tokens in case the input contains any of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_attentions = False
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
# therefore we just need to keep the original shape
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (SuperRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
position_ids=position_ids,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
sliding_window=getattr(self, "sliding_window", None),
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
is_causal=self.is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class SuperSdpaAttention(SuperAttention):
|
||||||
|
"""
|
||||||
|
Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||||
|
`SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||||
|
SDPA API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Adapted from SuperAttention.forward
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if output_attentions:
|
||||||
|
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||||
|
logger.warning_once(
|
||||||
|
"SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||||
|
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
causal_mask = attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
|
||||||
|
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||||
|
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||||
|
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||||
|
query_states = query_states.contiguous()
|
||||||
|
key_states = key_states.contiguous()
|
||||||
|
value_states = value_states.contiguous()
|
||||||
|
|
||||||
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
|
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||||
|
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=causal_mask,
|
||||||
|
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.view(bsz, q_len, -1)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
SUPER_ATTENTION_CLASSES = {
|
||||||
|
"eager": SuperAttention,
|
||||||
|
"flash_attention_2": SuperFlashAttention2,
|
||||||
|
"sdpa": SuperSdpaAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SuperDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: SuperConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
|
self.mlp = SuperMLP(config)
|
||||||
|
self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*):
|
||||||
|
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||||
|
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence
|
||||||
|
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||||
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
|
kwargs (`dict`, *optional*):
|
||||||
|
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||||
|
into the model
|
||||||
|
"""
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
SUPER_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
|
etc.)
|
||||||
|
|
||||||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||||
|
and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`SuperConfig`]):
|
||||||
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||||
|
load the weights associated with the model, only the configuration. Check out the
|
||||||
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Super Model outputting raw hidden-states without any specific head on top.",
|
||||||
|
SUPER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class SuperPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = SuperConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["SuperDecoderLayer"]
|
||||||
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
std = self.config.initializer_range
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
|
SUPER_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
||||||
|
`past_key_values`).
|
||||||
|
|
||||||
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||||
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||||
|
information on the default strategy.
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
|
model's internal embedding lookup matrix.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Super Model outputting raw hidden-states without any specific head on top.",
|
||||||
|
SUPER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class SuperModel(SuperPreTrainedModel):
|
||||||
|
"""
|
||||||
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuperDecoderLayer`]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: SuperConfig
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: SuperConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[SuperDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.rotary_emb = SuperRotaryEmbedding(config=config)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
out = super().forward(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
inputs_embeds,
|
||||||
|
use_cache,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
return_dict,
|
||||||
|
cache_position,
|
||||||
|
)
|
||||||
|
out.logits *= 2**4
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
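# Editorial sketch (not part of the generated file above): the `SUPER_ATTENTION_CLASSES` mapping is what
# `SuperDecoderLayer.__init__` indexes with `config._attn_implementation`, so the attention backend is picked once
# per layer at construction time. A minimal, self-contained illustration of that dispatch pattern follows; the
# `DemoConfig` class and the stand-in attention classes are assumptions for demonstration only.
class EagerAttention: ...
class FlashAttention2: ...
class SdpaAttention: ...

DEMO_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
    "sdpa": SdpaAttention,
}

class DemoConfig:
    _attn_implementation = "sdpa"  # normally set when the model is loaded, e.g. attn_implementation="sdpa"

attn_cls = DEMO_ATTENTION_CLASSES[DemoConfig._attn_implementation]
print(attn_cls.__name__)  # -> SdpaAttention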
@@ -3,10 +3,11 @@ from typing import List, Optional, Tuple, Union

 import torch

-from transformers import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaModel

+from ...cache_utils import Cache


 def _pre_process_input(input_ids):
     print(log(input_ids))
examples/modular-transformers/modular_dummy_bert.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from typing import List, Optional, Tuple, Union

import torch

from transformers.models.bert.modeling_bert import BertModel

from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class DummyBertModel(BertModel):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        return super().forward(input_ids)
@@ -5,10 +5,11 @@ from transformers.models.llama.configuration_llama import LlamaConfig
 # here there is no `ARG` so we are gonna take parent doc
 class MyNewModelConfig(LlamaConfig):
     r"""
-    mlp_bias (`bool`, *optional*, defaults to `False`)
+    new_param (`int`, *optional*, defaults to `False`):
+        A fun new parameter
     """

     def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
+        super().__init__(self, **super_kwargs)
         self.mlp_bias = mlp_bias
         self.new_param = new_param
-        super().__init__(self, **super_kwargs)
@@ -26,5 +26,10 @@ class NewModelConfig(GemmaConfig):
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
+        **kwargs,
     ):
-        super().__init__(self)
+        super().__init__(self, **kwargs)
+
+    @property
+    def num_heads(self):
+        return self.num_attention_heads
examples/modular-transformers/modular_roberta.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import torch.nn as nn

from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel


class RobertaEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, config.pad_token_id
        )


class RobertaModel(BertModel):
    def __init__(self, config):
        super().__init__(self, config)
        # Error out here. Why? Because `RobertaEmbeddings` is defined but not used.
        # no, because it's defined, and RobertaModel should use RobertaEmbedding
        # here if initialized that way it won't use the new embedding.
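# Editorial sketch (not part of the example file above): the closing comments discuss whether the generated
# RobertaModel will actually pick up the new RobertaEmbeddings. One way to guarantee it, shown here as an
# assumption rather than the library's actual generated output, is to re-assign the embedding module explicitly
# in the subclass __init__. The config attributes follow the usual BERT/RoBERTa naming and are assumed.
import torch.nn as nn

class DemoBertEmbeddings(nn.Module):  # stand-in for the real BertEmbeddings
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

class DemoRobertaEmbeddings(DemoBertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, config.pad_token_id
        )

class DemoBertModel(nn.Module):  # stand-in for the real BertModel
    def __init__(self, config):
        super().__init__()
        self.embeddings = DemoBertEmbeddings(config)

class DemoRobertaModel(DemoBertModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = DemoRobertaEmbeddings(config)  # explicit override so the new class is used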
@@ -2,10 +2,11 @@ from typing import List, Optional, Tuple, Union

 import torch

-from transformers import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaModel

+from ...cache_utils import Cache


 # example where we need some deps and some functions
 class SuperModel(LlamaModel):
setup.py
@@ -192,6 +192,8 @@ _deps = [
     "urllib3<2.0.0",
     "uvicorn",
     "pytest-rich",
+    "libcst",
+    "rich",
 ]


@@ -345,7 +347,7 @@ extras["testing"] = (

 extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
 extras["ruff"] = deps_list("ruff")
-extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3")
+extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3", "libcst", "rich")

 extras["all"] = (
     extras["tf"]
@@ -97,4 +97,6 @@ deps = {
     "urllib3": "urllib3<2.0.0",
     "uvicorn": "uvicorn",
     "pytest-rich": "pytest-rich",
+    "libcst": "libcst",
+    "rich": "rich",
 }
@@ -1,8 +1,8 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the modular. If any change should be done, please apply the change to the
-# diff.py file directly.
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@@ -21,7 +21,7 @@
 # limitations under the License.


-from transformers import PretrainedConfig
+from ...configuration_utils import PretrainedConfig


 class GemmaConfig(PretrainedConfig):
@@ -1,8 +1,8 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
+# the file from the modular. If any change should be done, please apply the change to the
-# diff.py file directly.
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@@ -39,7 +39,6 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -51,63 +50,6 @@ from ...utils import (
 from .configuration_gemma import GemmaConfig


-logger = logging.get_logger(__name__)
-
-
-# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    min_dtype: float,
-    cache_position: torch.Tensor,
-    batch_size: int,
-):
-    """
-    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-    Args:
-        attention_mask (`torch.Tensor`):
-            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-        sequence_length (`int`):
-            The sequence length being processed.
-        target_length (`int`):
-            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-        dtype (`torch.dtype`):
-            The dtype to use for the 4D attention mask.
-        device (`torch.device`):
-            The device to plcae the 4D attention mask on.
-        min_dtype (`float`):
-            The minimum value representable with the dtype `dtype`.
-        cache_position (`torch.Tensor`):
-            Indices depicting the position of the input sequence tokens in the sequence.
-        batch_size (`torch.Tensor`):
-            Batch size.
-    """
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
-
-    return causal_mask
-
-
 class GemmaRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -128,7 +70,7 @@ class GemmaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
+logger = logging.get_logger(__name__)


 class GemmaRotaryEmbedding(nn.Module):
@@ -159,30 +101,6 @@ class GemmaRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class GemmaMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        if config.hidden_activation is None:
-            logger.warning_once(
-                "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
-                "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
-                "`config.hidden_activation` if you want to override this behaviour.\n"
-                "See https://github.com/huggingface/transformers/pull/29402 for more details."
-            )
-            config.hidden_activation = "gelu_pytorch_tanh"
-        hidden_activation = config.hidden_activation
-        self.act_fn = ACT2FN[hidden_activation]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
     """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -212,6 +130,30 @@ class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
         return cos, sin


+class GemmaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_activation is None:
+            logger.warning_once(
+                "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
+                "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
+                "`config.hidden_activation` if you want to override this behaviour.\n"
+                "See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            config.hidden_activation = "gelu_pytorch_tanh"
+        hidden_activation = config.hidden_activation
+        self.act_fn = ACT2FN[hidden_activation]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -358,6 +300,94 @@ class GemmaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


+class GemmaSdpaAttention(GemmaAttention):
+    """
+    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from GemmaAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 class GemmaFlashAttention2(GemmaAttention):
     """
     Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -458,7 +488,6 @@ class GemmaFlashAttention2(GemmaAttention):
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )

         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

@ -468,92 +497,57 @@ class GemmaFlashAttention2(GemmaAttention):
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
-class GemmaSdpaAttention(GemmaAttention):
-    """
-    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    # Adapted from GemmaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    min_dtype: float,
+    cache_position: torch.Tensor,
+    batch_size: int,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+        dtype (`torch.dtype`):
+            The dtype to use for the 4D attention mask.
+        device (`torch.device`):
+            The device to place the 4D attention mask on.
+        min_dtype (`float`):
+            The minimum value representable with the dtype `dtype`.
+        cache_position (`torch.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`torch.Tensor`):
+            Batch size.
+    """
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
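As a rough illustration of what the helper added above produces, here is a self-contained sketch of the 2D-to-4D mask expansion (it mirrors the `else` branch of the function; sizes are made up):

```python
import torch

batch_size, sequence_length, target_length = 1, 3, 5
dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # queries sit at absolute positions 2..4

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
print(causal_mask.shape)  # torch.Size([1, 1, 3, 5]); 0.0 where attending is allowed, min_dtype elsewhere
```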
 GEMMA_ATTENTION_CLASSES = {
@@ -567,9 +561,7 @@ class GemmaDecoderLayer(nn.Module):
     def __init__(self, config: GemmaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

         self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

         self.mlp = GemmaMLP(config)
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -830,9 +822,9 @@ class GemmaModel(GemmaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

         # kept for BC (non `Cache` `past_key_values` inputs)
-        return_legacy_cache = False
+        return_legacy_cache = False  # noqa: F841
         if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
+            return_legacy_cache = True  # noqa: F841
             if past_key_values is None:
                 past_key_values = DynamicCache()
             else:
@@ -975,6 +967,7 @@ class GemmaModel(GemmaPreTrainedModel):
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )

         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
@@ -1149,6 +1142,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]

            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

@@ -1230,7 +1224,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
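The `GEMMA_ATTENTION_CLASSES` mapping and the `config._attn_implementation == "sdpa"` check above are what the `attn_implementation` loading argument ends up selecting. A typical call looks roughly like this (checkpoint name is only an example):

```python
from transformers import AutoModelForCausalLM

# "eager", "sdpa" and "flash_attention_2" correspond to the entries of GEMMA_ATTENTION_CLASSES.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",           # illustrative checkpoint
    attn_implementation="sdpa",  # switch to "eager" if attention weights are needed
)
```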
@@ -21,8 +21,15 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from transformers import PretrainedConfig
-from transformers.models.llama.modeling_llama import (
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import is_torchdynamo_compiling, logging
+from ..llama.modeling_llama import (
+    LlamaDecoderLayer,
     LlamaFlashAttention2,
     LlamaForCausalLM,
     LlamaForSequenceClassification,
@@ -32,14 +39,6 @@ from transformers.models.llama.modeling_llama import (
     repeat_kv,
 )

-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_outputs import CausalLMOutputWithPast
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import logging

 logger = logging.get_logger(__name__)

@@ -216,6 +215,35 @@ class GemmaRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
+
+
+class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
+
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
+
+
 class GemmaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
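The two scaling variants added above differ only in how positions are pre-processed before computing cos/sin; a condensed standalone sketch of the idea (values are arbitrary, this is not the transformers implementation):

```python
import torch

scaling_factor, max_position_embeddings, dim, base = 2.0, 8192, 8, 10000.0
position_ids = torch.arange(0, 12000).unsqueeze(0)

# Linear scaling: stretch the position ids back into the trained range.
linear_position_ids = position_ids.float() / scaling_factor

# Dynamic NTK scaling: keep the positions, but enlarge the rotary base once the
# sequence outgrows the original training length, which stretches the low frequencies.
seq_len = torch.max(position_ids) + 1
if seq_len > max_position_embeddings:
    base = base * ((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(linear_position_ids[0, -1].item(), inv_freq.shape)
```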
@@ -340,8 +368,95 @@ class GemmaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention?
-class GemmaFlashAttention2(LlamaFlashAttention2):
+class GemmaSdpaAttention(GemmaAttention):
+    """
+    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from GemmaAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
     """
     Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -427,12 +542,12 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self, "sliding_window", None),
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )

         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

@@ -442,7 +557,95 @@ class GemmaFlashAttention2(LlamaFlashAttention2):
         return attn_output, attn_weights, past_key_value


+GEMMA_ATTENTION_CLASSES = {
+    "eager": GemmaAttention,
+    "flash_attention_2": GemmaFlashAttention2,
+    "sdpa": GemmaSdpaAttention,
+}
+
+
+class GemmaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: GemmaConfig, layer_idx: int):
+        super().__init__(config)
+        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.mlp = GemmaMLP(config)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
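The decoder layer defined above is the usual pre-norm residual block; stripped of caching, masks and rotary details it reduces to roughly the following sketch (illustrative module, not the actual class):

```python
import torch
from torch import nn

class TinyPreNormBlock(nn.Module):
    """Illustrative stand-in: single-head attention + MLP, each wrapped in a residual."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=1, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(normed, normed, normed)
        hidden_states = residual + attn_out  # residual around attention

        residual = hidden_states
        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))  # residual around MLP
        return hidden_states

print(TinyPreNormBlock()(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 32])
```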
 class GemmaModel(LlamaModel):
+    def __init__(self, config: GemmaConfig):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        del self.rotary_emb  # Gemma does not implement rotary emb at the modeling level yet!
+        self.post_init()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -455,7 +658,7 @@ class GemmaModel(LlamaModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -513,22 +716,72 @@ class GemmaModel(LlamaModel):
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer

-        return super().forward(
-            causal_mask,
-            position_ids,
-            past_key_values,
-            use_cache,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
-            cache_position,
-            input_ids=None,
-            inputs_embeds=hidden_states,
-        )
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )


 # Example where we only modify the docstring and call super
-class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
+class GemmaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = GemmaModel(config)
+        self.post_init()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -542,18 +795,9 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        Returns:
-
-        Example:
-
         ```python
         >>> from transformers import AutoTokenizer, GemmaForCausalLM

@@ -589,10 +833,18 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):
         )

         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
+        if labels is None and not is_torchdynamo_compiling():
+            logger.warning_once(
+                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
+            )
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        # TODO: remove the float() operation in v4.46
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
@@ -618,8 +870,14 @@ class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin):


 class GemmaForSequenceClassification(LlamaForSequenceClassification):
-    pass
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = GemmaModel(config)
+        self.post_init()


 class GemmaForTokenClassification(LlamaForTokenClassification):
-    pass
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = GemmaModel(config)
+        self.post_init()
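The `num_logits_to_keep` change above avoids materialising logits for every position when only the last ones are needed (the common case during generation). A toy illustration of the slicing, with made-up sizes:

```python
import torch
from torch import nn

vocab_size, hidden_size, seq_len = 11, 4, 7
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, seq_len, hidden_size)

num_logits_to_keep = 1  # while decoding, only the last position's logits are used
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
print(logits.shape)  # torch.Size([1, 1, 11]) instead of torch.Size([1, 7, 11])
```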
@@ -1,8 +1,8 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@@ -19,7 +19,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from transformers import PretrainedConfig
+
+from ...configuration_utils import PretrainedConfig


 class Gemma2Config(PretrainedConfig):
@@ -53,7 +55,8 @@ class Gemma2Config(PretrainedConfig):
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
         hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
-            The non-linear activation function (function or string) in the decoder.
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
         max_position_embeddings (`int`, *optional*, defaults to 8192):
             The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -77,16 +80,17 @@ class Gemma2Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
-        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
+        final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.

     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
-    >>> # Initializing a Gemma2 gemma2-9b style configuration
+    >>> # Initializing a Gemma2 gemma2-7b style configuration
     >>> configuration = Gemma2Config()
-    >>> # Initializing a model from the gemma2-9b style configuration
+    >>> # Initializing a model from the gemma2-7b style configuration
     >>> model = Gemma2Model(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
@@ -94,6 +98,7 @@ class Gemma2Config(PretrainedConfig):

     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
+    cache_implementation = "hybrid"

     def __init__(
         self,
@@ -116,12 +121,19 @@ class Gemma2Config(PretrainedConfig):
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
-        final_logit_softcapping=30.0,
-        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         **kwargs,
     ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -130,23 +142,14 @@ class Gemma2Config(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
-        self.hidden_activation = hidden_activation
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.attn_logit_softcapping = attn_logit_softcapping
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-        self.final_logit_softcapping = final_logit_softcapping
+        self.hidden_activation = hidden_activation
         self.query_pre_attn_scalar = query_pre_attn_scalar
         self.sliding_window = sliding_window
-        self.cache_implementation = "hybrid"
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
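A quick sanity check of the reorganised constructor, roughly as the regenerated `configuration_gemma2.py` accepts it (the values shown are simply the documented defaults):

```python
from transformers import Gemma2Config

config = Gemma2Config(
    query_pre_attn_scalar=224,
    sliding_window=4096,
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
)
print(config.cache_implementation)  # "hybrid" is now a class attribute instead of being set in __init__
```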
@@ -1,8 +1,8 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
@@ -22,13 +22,14 @@
 from typing import List, Optional, Tuple, Union

 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
-from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache
 from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,7 +40,6 @@ from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_torchdynamo_compiling,
@@ -49,66 +49,6 @@ from ...utils import (
 from .configuration_gemma2 import Gemma2Config


-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
-logger = logging.get_logger(__name__)
-
-
-# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    min_dtype: float,
-    cache_position: torch.Tensor,
-    batch_size: int,
-):
-    """
-    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-    Args:
-        attention_mask (`torch.Tensor`):
-            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-        sequence_length (`int`):
-            The sequence length being processed.
-        target_length (`int`):
-            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-        dtype (`torch.dtype`):
-            The dtype to use for the 4D attention mask.
-        device (`torch.device`):
-            The device to plcae the 4D attention mask on.
-        min_dtype (`float`):
-            The minimum value representable with the dtype `dtype`.
-        cache_position (`torch.Tensor`):
-            Indices depicting the position of the input sequence tokens in the sequence.
-        batch_size (`torch.Tensor`):
-            Batch size.
-    """
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
-    return causal_mask
-
-
 class Gemma2RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -129,6 +69,24 @@ class Gemma2RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


+class Gemma2MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_activation]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+logger = logging.get_logger(__name__)
+
+
 class Gemma2RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -191,21 +149,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-class Gemma2MLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
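As the docstring above says, `repeat_kv` is just `torch.repeat_interleave` along the head dimension, used to expand grouped key/value heads to the full number of query heads; a standalone sketch:

```python
import torch

batch, num_kv_heads, seq_len, head_dim, n_rep = 2, 2, 5, 8, 4
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# (batch, num_kv_heads, seq, dim) -> (batch, num_kv_heads * n_rep, seq, dim)
expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
print(expanded.shape)  # torch.Size([2, 8, 5, 8])
```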
@@ -253,12 +196,12 @@ class Gemma2Attention(nn.Module):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
         self.rotary_emb = Gemma2RotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
         )
-        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None

     def forward(
         self,
@@ -502,9 +445,11 @@ class Gemma2SdpaAttention(Gemma2Attention):

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
@@ -515,6 +460,7 @@ class Gemma2SdpaAttention(Gemma2Attention):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -533,6 +479,59 @@ class Gemma2SdpaAttention(Gemma2Attention):
         return attn_output, None, past_key_value

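The `sliding_window` line moved above encodes Gemma2's alternation between sliding-window and global attention: even-indexed layers get a window, odd-indexed layers do not. A one-liner showing the pattern (window size is just the config default):

```python
sliding_window = 4096  # config.sliding_window
layer_windows = [sliding_window if not bool(layer_idx % 2) else None for layer_idx in range(6)]
print(layer_windows)  # [4096, None, 4096, None, 4096, None]
```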
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
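The helper added above follows a fixed recipe: start from an additive mask filled with the dtype minimum, unmask the lower triangle, mask cache slots beyond the current `cache_position`, then fold in the 2D padding mask. A condensed, self-contained sketch of that recipe with made-up sizes (illustrative only; the merged helper is the reference implementation):

```python
import torch

def make_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    # Everything starts fully masked (additive mask of dtype-min)...
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)          # keep the lower triangle visible
    # ...and cache positions not written yet stay masked as well.
    mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:                    # fold in the 2D padding information
        length = attention_mask.shape[-1]
        padding = (mask[:, :, :, :length] + attention_mask[:, None, None, :]) == 0
        mask[:, :, :, :length] = mask[:, :, :, :length].masked_fill(padding, min_dtype)
    return mask

# Toy usage: 4 query tokens, a static cache of length 8, one padded position.
pad = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1]])
mask = make_4d_causal_mask(pad, 4, 8, torch.float32, "cpu", torch.arange(4), batch_size=1)
print(mask.shape)  # torch.Size([1, 1, 4, 8])
```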
GEMMA2_ATTENTION_CLASSES = {
|
GEMMA2_ATTENTION_CLASSES = {
|
||||||
"eager": Gemma2Attention,
|
"eager": Gemma2Attention,
|
||||||
"flash_attention_2": Gemma2FlashAttention2,
|
"flash_attention_2": Gemma2FlashAttention2,
|
||||||
@ -543,19 +542,16 @@ GEMMA2_ATTENTION_CLASSES = {
|
|||||||
class Gemma2DecoderLayer(nn.Module):
|
class Gemma2DecoderLayer(nn.Module):
|
||||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
self.mlp = Gemma2MLP(config)
|
self.mlp = Gemma2MLP(config)
|
||||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.config = config
|
||||||
|
|
||||||
self.is_sliding = not bool(layer_idx % 2)
|
self.is_sliding = not bool(layer_idx % 2)
|
||||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.sliding_window = config.sliding_window
|
self.sliding_window = config.sliding_window
|
||||||
|
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -567,6 +563,25 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*):
|
||||||
|
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||||
|
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence
|
||||||
|
kwargs (`dict`, *optional*):
|
||||||
|
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||||
|
into the model
|
||||||
|
"""
|
||||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||||
# Flash-attn is a 2D tensor
|
# Flash-attn is a 2D tensor
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
@ -580,6 +595,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||||
if attention_mask.shape[-1] <= 1: # when decoding
|
if attention_mask.shape[-1] <= 1: # when decoding
|
||||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||||
|
|
||||||
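The hunk above replaces the old `tril`-multiplication trick with an explicit boolean `sliding_window_mask` plus `torch.where`, and slices the last `sliding_window` columns when decoding a single token. A rough standalone illustration of that masking step on a toy causal mask (not the merged code):

```python
import torch

sliding_window = 4
seq_len = 8
min_dtype = torch.finfo(torch.float32).min

# Start from an already-built additive causal mask (0 = visible, min = hidden).
attention_mask = torch.triu(torch.full((1, 1, seq_len, seq_len), min_dtype), diagonal=1)

# Positions more than `sliding_window` tokens in the past fall inside this tril
# (diagonal=-sliding_window) and get pushed to the dtype minimum as well.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# The last row can now only attend to the final `sliding_window` positions.
print((attention_mask[0, 0, 7] == 0).nonzero().flatten())  # tensor([4, 5, 6, 7])
```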
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@ -711,13 +727,20 @@ GEMMA2_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`HybridCache`, *optional*):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
|
Two formats are allowed:
|
||||||
cache classes.
|
- a [`~cache_utils.Cache`] instance, see our
|
||||||
|
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
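The docstring rows above swap the Gemma2-specific `HybridCache` wording for the generic `Cache`-or-legacy-tuples description. As a hedged usage sketch (the checkpoint id is a placeholder, and whether a given model restricts the accepted cache classes is model-specific), an explicit cache object can be created once and threaded through successive forward calls:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# Placeholder checkpoint id; any causal LM available locally behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

inputs = tokenizer("The capital of France is", return_tensors="pt")
cache = DynamicCache()  # a `Cache` instance; the legacy tuple-of-tuples format is also accepted

with torch.no_grad():
    out = model(**inputs, past_key_values=cache, use_cache=True)

# The cache comes back in the same format it was passed in, so the next step
# only needs the newly predicted token id.
next_token = out.logits[:, -1:].argmax(-1)
with torch.no_grad():
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
```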
@ -812,8 +835,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Instantiate an empty cache if needed.
|
if use_cache and past_key_values is None and not self.training:
|
||||||
if use_cache and past_key_values is None:
|
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
@ -828,6 +850,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
@ -844,6 +867,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||||
hidden_states = hidden_states * normalizer
|
hidden_states = hidden_states * normalizer
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
|
|
||||||
@ -880,7 +904,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
@ -1009,6 +1032,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
|||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
"What is your favorite condiment?"
|
"What is your favorite condiment?"
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
if self.training and self.config._attn_implementation != "eager":
|
if self.training and self.config._attn_implementation != "eager":
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
|
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
|
||||||
@ -1187,10 +1211,10 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[HybridCache] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
@ -13,30 +13,41 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from transformers.models.gemma.configuration_gemma import GemmaConfig
|
from ...activations import ACT2FN
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
from ...cache_utils import Cache, HybridCache
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from ...utils import (
|
||||||
|
is_flash_attn_2_available,
|
||||||
|
is_flash_attn_greater_or_equal,
|
||||||
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from ..gemma.modeling_gemma import (
|
||||||
GemmaAttention,
|
GemmaAttention,
|
||||||
GemmaDecoderLayer,
|
GemmaDecoderLayer,
|
||||||
GemmaForCausalLM,
|
GemmaForCausalLM,
|
||||||
GemmaForSequenceClassification,
|
GemmaForSequenceClassification,
|
||||||
GemmaForTokenClassification,
|
GemmaForTokenClassification,
|
||||||
GemmaModel,
|
GemmaModel,
|
||||||
|
GemmaPreTrainedModel,
|
||||||
GemmaRMSNorm,
|
GemmaRMSNorm,
|
||||||
|
_prepare_4d_causal_attention_mask_with_cache_position,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...cache_utils import Cache
|
|
||||||
from ...generation import GenerationMixin
|
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
||||||
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||||
@ -45,33 +56,230 @@ if is_flash_attn_2_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Gemma2Config(GemmaConfig):
|
class Gemma2Config(PretrainedConfig):
|
||||||
cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the Gemma2-7B.
|
||||||
|
e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 256000):
|
||||||
|
Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`Gemma2Model`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details, check out [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
head_dim (`int`, *optional*, defaults to 256):
|
||||||
|
The attention head dimension.
|
||||||
|
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||||
|
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 0):
|
||||||
|
Padding token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
End of stream token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
Beginning of stream token id.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||||
|
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||||
|
size of the sliding window.
|
||||||
|
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||||
|
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Gemma2Model, Gemma2Config
|
||||||
|
>>> # Initializing a Gemma2 gemma2-7b style configuration
|
||||||
|
>>> configuration = Gemma2Config()
|
||||||
|
>>> # Initializing a model from the gemma2-7b style configuration
|
||||||
|
>>> model = Gemma2Model(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "gemma2"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
cache_implementation = "hybrid"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
vocab_size=256000,
|
||||||
|
hidden_size=3072,
|
||||||
|
intermediate_size=24576,
|
||||||
|
num_hidden_layers=28,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_key_value_heads=16,
|
||||||
|
head_dim=256,
|
||||||
|
hidden_activation="gelu_pytorch_tanh",
|
||||||
|
max_position_embeddings=8192,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
bos_token_id=2,
|
||||||
|
tie_word_embeddings=True,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
query_pre_attn_scalar=224,
|
query_pre_attn_scalar=224,
|
||||||
sliding_window=4096,
|
sliding_window=4096,
|
||||||
final_logit_softcapping=30.0,
|
final_logit_softcapping=30.0,
|
||||||
**super_kwargs,
|
attn_logit_softcapping=50.0,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(self, **super_kwargs)
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.cache_implementation = "hybrid"
|
|
||||||
self.final_logit_softcapping = final_logit_softcapping
|
self.final_logit_softcapping = final_logit_softcapping
|
||||||
|
self.attn_logit_softcapping = attn_logit_softcapping
|
||||||
|
|
||||||
|
|
||||||
class Gemma2RMSNorm(GemmaRMSNorm):
|
class Gemma2RMSNorm(GemmaRMSNorm):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma2MLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_activation]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
class Gemma2Attention(GemmaAttention):
|
class Gemma2Attention(GemmaAttention):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||||
super().__init__(config, layer_idx)
|
super().__init__(config, layer_idx)
|
||||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||||
|
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {
|
||||||
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"sliding_window": self.sliding_window,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||||
|
|
||||||
|
if self.config.attn_logit_softcapping is not None:
|
||||||
|
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||||
|
attn_weights = torch.tanh(attn_weights)
|
||||||
|
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||||
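The four lines above are the attention-logit softcapping that distinguishes Gemma2 from Gemma: scores are squashed into `(-cap, +cap)` via `cap * tanh(scores / cap)` before the mask and softmax are applied. A tiny numeric illustration (values made up):

```python
import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds the logits to (-cap, +cap); small values pass through almost unchanged.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(softcap(scores, cap=50.0))
# the +/-200 inputs land just below +/-50, while +/-10 only shrinks to about +/-9.87
```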
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, q_len, -1)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
class Gemma2FlashAttention2(Gemma2Attention):
|
class Gemma2FlashAttention2(Gemma2Attention):
|
||||||
@ -119,9 +327,19 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
|||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
cache_kwargs = {
|
||||||
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"sliding_window": self.sliding_window,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
}
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
seq_len = attention_mask.shape[1]
|
||||||
|
key_states = key_states[:, :, :seq_len]
|
||||||
|
value_states = value_states[:, :, :seq_len]
|
||||||
|
|
||||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
# to be able to avoid many of these transpose/reshape/view.
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
@ -156,7 +374,6 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
|||||||
key_states = key_states.to(target_dtype)
|
key_states = key_states.to(target_dtype)
|
||||||
value_states = value_states.to(target_dtype)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING
|
|
||||||
attn_output = _flash_attention_forward(
|
attn_output = _flash_attention_forward(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
@ -166,7 +383,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
|||||||
dropout=dropout_rate,
|
dropout=dropout_rate,
|
||||||
softmax_scale=self.scaling,
|
softmax_scale=self.scaling,
|
||||||
is_causal=self.is_causal,
|
is_causal=self.is_causal,
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||||
@ -227,7 +446,12 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
|||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
cache_kwargs = {
|
||||||
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"sliding_window": self.sliding_window,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
}
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
@ -269,8 +493,9 @@ class Gemma2SdpaAttention(Gemma2Attention):
|
|||||||
class Gemma2DecoderLayer(GemmaDecoderLayer):
|
class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||||
super().__init__(config, layer_idx)
|
super().__init__(config, layer_idx)
|
||||||
|
self.config = config
|
||||||
self.is_sliding = bool(layer_idx % 2)
|
self.is_sliding = not bool(layer_idx % 2)
|
||||||
|
self.mlp = Gemma2MLP(config)
|
||||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.sliding_window = config.sliding_window
|
self.sliding_window = config.sliding_window
|
||||||
@ -286,11 +511,18 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||||
attention_mask = attention_mask * torch.tril(
|
# Flash-attn is a 2D tensor
|
||||||
torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
)
|
if past_key_value is not None: # when decoding
|
||||||
if cache_position[0] > 0:
|
|
||||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||||
|
else:
|
||||||
|
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||||
|
sliding_window_mask = torch.tril(
|
||||||
|
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||||
|
)
|
||||||
|
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||||
|
if attention_mask.shape[-1] <= 1: # when decoding
|
||||||
|
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
@ -326,13 +558,38 @@ class Gemma2DecoderLayer(GemmaDecoderLayer):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class Gemma2Model(GemmaModel):
|
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
|
||||||
|
_supports_quantized_cache = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
|
||||||
|
"""
|
||||||
|
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
|
||||||
|
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
|
||||||
|
"""
|
||||||
|
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
|
||||||
|
|
||||||
|
# if using the default path -> swap sdpa by eager
|
||||||
|
if not hard_check_only and config._attn_implementation == "sdpa":
|
||||||
|
config._attn_implementation = "eager"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
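The `_check_and_enable_sdpa` override above silently swaps the default SDPA path for eager attention, since SDPA offers no hook for the attention-logit softcapping. Callers can still pick a backend explicitly; a hedged example (the checkpoint id is a placeholder):

```python
from transformers import AutoModelForCausalLM

# Request the attention backend explicitly; "eager" is the recommended choice for
# training Gemma2 because it applies the logit softcapping exactly.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",              # placeholder checkpoint id
    attn_implementation="eager",      # or "sdpa" / "flash_attention_2" if the trade-offs are acceptable
)
print(model.config._attn_implementation)  # "eager"
```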
|
class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
|
||||||
|
def __init__(self, config: Gemma2Config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
past_key_values: Optional[HybridCache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@ -361,8 +618,21 @@ class Gemma2Model(GemmaModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if use_cache and past_key_values is None and not self.training:
|
||||||
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
|
past_key_values = HybridCache(
|
||||||
|
self.config,
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_cache_len=seq_len,
|
||||||
|
device=self.device,
|
||||||
|
dtype=inputs_embeds.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
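The block above lazily allocates a `HybridCache` when caching is requested but none was supplied. The same cache can also be built up front and handed to `generate`; the sketch below assumes the `HybridCache` constructor keeps the `(config, batch_size, max_cache_len, device, dtype)` signature used in this diff, and the checkpoint id is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import HybridCache

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")   # placeholder checkpoint id
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

inputs = tokenizer("Sliding window attention", return_tensors="pt")
max_new_tokens = 32

# Size the cache for prompt + generated tokens; the shapes are fixed up front,
# which is what makes the hybrid (static + sliding) layout compile-friendly.
cache = HybridCache(
    model.config,
    batch_size=inputs.input_ids.shape[0],
    max_cache_len=inputs.input_ids.shape[1] + max_new_tokens,
    device=model.device,
    dtype=model.dtype,
)

out = model.generate(**inputs, past_key_values=cache, max_new_tokens=max_new_tokens)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```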
if cache_position is None:
|
if cache_position is None:
|
||||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
@ -437,50 +707,50 @@ class Gemma2Model(GemmaModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Cache,
|
past_key_values: HybridCache,
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
|
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
|
||||||
|
# So we will pass in attention mask as is in any case, not only when there's padding. Then we'll use its shape
|
||||||
|
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
|
||||||
|
# as it doesn't cause dynamic control issues.
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if past_key_values is not None:
|
if isinstance(past_key_values, HybridCache):
|
||||||
target_length = past_key_values.get_max_length()
|
target_length = past_key_values.get_max_length()
|
||||||
else:
|
else:
|
||||||
target_length = attention_mask.shape[-1]
|
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||||
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
causal_mask = attention_mask
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
else:
|
attention_mask,
|
||||||
causal_mask = torch.full(
|
sequence_length=sequence_length,
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
target_length=target_length,
|
||||||
)
|
dtype=dtype,
|
||||||
if sequence_length != 1:
|
device=device,
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
min_dtype=min_dtype,
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
cache_position=cache_position,
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
batch_size=input_tensor.shape[0],
|
||||||
if attention_mask is not None:
|
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
|
||||||
mask_length = attention_mask.shape[-1]
|
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
|
||||||
padding_mask = padding_mask == 0
|
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
||||||
padding_mask, min_dtype
|
|
||||||
)
|
)
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = Gemma2Model(config)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
past_key_values: Optional[HybridCache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@ -488,18 +758,9 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
num_logits_to_keep: int = 0,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||||
|
|
||||||
@ -514,12 +775,17 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
|||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
"What is your favorite condiment?"
|
"What is your favorite condiment?"
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
|
if self.training and self.config._attn_implementation != "eager":
|
||||||
|
logger.warning_once(
|
||||||
|
"It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
|
||||||
|
f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
|
||||||
|
)
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@ -535,15 +801,23 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
logits = self.lm_head(hidden_states)
|
if labels is None and not is_torchdynamo_compiling():
|
||||||
|
logger.warning_once(
|
||||||
|
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
|
||||||
|
)
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
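The replacement above runs the LM head only on the last `num_logits_to_keep` positions and defers the float upcast until a loss is actually computed, avoiding a full `(batch, seq_len, vocab)` allocation at decode time. A shape-only illustration with dummy modules:

```python
import torch
import torch.nn as nn

batch, seq_len, hidden, vocab = 2, 512, 64, 1000
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)

# During decoding only the last position is needed; note that 0 keeps every
# position, because hidden_states[:, -0:, :] slices the whole axis.
num_logits_to_keep = 1

full = lm_head(hidden_states)                                   # (2, 512, 1000)
sliced = lm_head(hidden_states[:, -num_logits_to_keep:, :])     # (2, 1, 1000)
print(full.shape, sliced.shape)
assert torch.allclose(full[:, -1:, :], sliced)
```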
if self.config.final_logit_softcapping is not None:
|
if self.config.final_logit_softcapping is not None:
|
||||||
logits = logits / self.config.final_logit_softcapping
|
logits = logits / self.config.final_logit_softcapping
|
||||||
logits = torch.tanh(logits)
|
logits = torch.tanh(logits)
|
||||||
logits = logits * self.config.final_logit_softcapping
|
logits = logits * self.config.final_logit_softcapping
|
||||||
|
|
||||||
|
# TODO: remove the float() operation in v4.46
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
@ -567,10 +841,94 @@ class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||||
|
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||||
|
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||||
|
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||||
|
# which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
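The comment block above is about CUDA-graph stability under `torch.compile(mode="reduce-overhead")`: a re-capture is triggered whenever an input's strides change, and a sliced `position_ids` is a view whose strides depend on the original tensor. A small stride-only demonstration on CPU (no compile involved):

```python
import torch

attention_mask = torch.ones(1, 10, dtype=torch.long)
position_ids = attention_mask.long().cumsum(-1) - 1

sliced = position_ids[:, -1:]                                    # a view into the original storage
cloned = sliced.clone(memory_format=torch.contiguous_format)     # fresh tensor with a stable layout

print(sliced.stride(), cloned.stride())   # e.g. (10, 1) vs (1, 1)
# The values are identical; only the memory layout (and thus the CUDA-graph
# capture behaviour under reduce-overhead) differs.
assert torch.equal(sliced, cloned)
```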
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||||
|
else:
|
||||||
|
# The clone here is for the same reason as for `position_ids`.
|
||||||
|
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(past_key_values, HybridCache)
|
||||||
|
and attention_mask.ndim == 2
|
||||||
|
and not self.config._attn_implementation == "flash_attention_2"
|
||||||
|
):
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
||||||
pass
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = Gemma2Model(config)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
|
||||||
class Gemma2ForTokenClassification(GemmaForTokenClassification):
|
class Gemma2ForTokenClassification(GemmaForTokenClassification):
|
||||||
pass
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = Gemma2Model(config)
|
||||||
|
self.post_init()
|
@ -1,8 +1,8 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_modular_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# diff.py file directly.
|
# modular_xxx.py file directly. One of our CI checks enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
@ -24,9 +24,7 @@ from typing import Union
|
|||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
from ...utils import (
|
from ...utils import logging
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from ..auto import CONFIG_MAPPING
|
from ..auto import CONFIG_MAPPING
|
||||||
|
|
||||||
|
|
||||||
@ -36,8 +34,8 @@ logger = logging.get_logger(__name__)
|
|||||||
class InstructBlipVideoVisionConfig(PretrainedConfig):
|
class InstructBlipVideoVisionConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
|
This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
|
||||||
instantiate a Instructblipvideo vision encoder according to the specified arguments, defining the model architecture.
|
instantiate a InstructBlipVideo vision encoder according to the specified arguments, defining the model architecture.
|
||||||
Instantiating a configuration defaults will yield a similar configuration to that of the Instructblipvideo
|
Instantiating a configuration defaults will yield a similar configuration to that of the InstructBlipVideo
|
||||||
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
@ -58,7 +56,7 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
|
|||||||
The size (resolution) of each patch.
|
The size (resolution) of each patch.
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
`"relu"`, `"selu"` and `"gelu_new"` `"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
||||||
normalization layers.
|
normalization layers.
|
||||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
@ -137,9 +135,9 @@ class InstructBlipVideoVisionConfig(PretrainedConfig):
|
|||||||
class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
|
This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
|
||||||
instantiate a Instructblipvideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
|
instantiate an InstructBlipVideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
|
||||||
model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||||
the Instructblipvideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
|
the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
|
||||||
architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
||||||
Read the documentation from [`PretrainedConfig`] for more information.
|
Read the documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
@ -189,7 +187,7 @@ class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
|||||||
```python
|
```python
|
||||||
>>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
|
>>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
|
||||||
|
|
||||||
>>> # Initializing a Instructblipvideo Salesforce/instruct-blip-flan-t5 style configuration
|
>>> # Initializing an InstructBlipVideo Salesforce/instruct-blip-flan-t5 style configuration
|
||||||
>>> configuration = InstructBlipVideoQFormerConfig()
|
>>> configuration = InstructBlipVideoQFormerConfig()
|
||||||
|
|
||||||
>>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
>>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||||
@ -360,7 +358,7 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a Instructblipvideo vision model, Q-Former and
|
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
|
||||||
language model configurations.
|
language model configurations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
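For context, a minimal sketch of the composition described by the docstring above (the text backbone chosen here is purely illustrative; any `PretrainedConfig` works):

```python
from transformers import (
    InstructBlipVideoConfig,
    InstructBlipVideoQFormerConfig,
    InstructBlipVideoVisionConfig,
    OPTConfig,  # illustrative text backbone
)

vision_config = InstructBlipVideoVisionConfig()
qformer_config = InstructBlipVideoQFormerConfig()
text_config = OPTConfig()

# Build the composite config from its three sub-configs, as documented above.
config = InstructBlipVideoConfig.from_vision_qformer_text_configs(
    vision_config=vision_config,
    qformer_config=qformer_config,
    text_config=text_config,
)
```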
@@ -1,8 +1,8 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -19,7 +19,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.


 import math
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
@@ -354,6 +353,21 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()


+INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
 INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
     Args:
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
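The `config` argument documented in the block above follows the standard `PreTrainedModel` pattern; a minimal sketch (checkpoint name deliberately left as a placeholder):

```python
from transformers import InstructBlipVideoConfig, InstructBlipVideoForConditionalGeneration

# Building from a config creates the architecture with freshly initialized weights;
# no pretrained weights are downloaded or loaded.
config = InstructBlipVideoConfig()
model = InstructBlipVideoForConditionalGeneration(config)

# To load pretrained weights instead, use from_pretrained as the docstring points out:
# model = InstructBlipVideoForConditionalGeneration.from_pretrained("<checkpoint>")
```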
@ -371,6 +385,71 @@ INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
|
|||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
||||||
|
[`InstructBlipVideoProcessor.__call__`] for details.
|
||||||
|
|
||||||
|
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||||
|
to serve as text prompt, which the Q-Former model will encode.
|
||||||
|
|
||||||
|
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||||
|
details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
|
||||||
|
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||||
|
provided to serve as text prompt, which the language model can continue.
|
||||||
|
|
||||||
|
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||||
|
details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
|
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||||
|
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
||||||
|
encoder-decoder language model (like T5) is used.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||||
|
|
||||||
|
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||||
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||||
|
be used by default.
|
||||||
|
|
||||||
|
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||||
|
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
|
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
|
||||||
class InstructBlipVideoEncoder(nn.Module):
|
class InstructBlipVideoEncoder(nn.Module):
|
||||||
@ -459,87 +538,6 @@ class InstructBlipVideoEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
|
||||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
||||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
||||||
etc.)
|
|
||||||
|
|
||||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
|
||||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
|
||||||
and behavior.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
|
|
||||||
Initializing with a config file does not load the weights associated with the model, only the
|
|
||||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
|
||||||
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
|
||||||
[`InstructBlipVideoProcessor.__call__`] for details.
|
|
||||||
|
|
||||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
|
||||||
to serve as text prompt, which the Q-Former model will encode.
|
|
||||||
|
|
||||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
|
||||||
details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
|
|
||||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
|
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
|
||||||
provided to serve as text prompt, which the language model can continue.
|
|
||||||
|
|
||||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
|
||||||
details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
|
|
||||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
||||||
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
|
||||||
encoder-decoder language model (like T5) is used.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
|
||||||
|
|
||||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
||||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
|
||||||
be used by default.
|
|
||||||
|
|
||||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to interpolate the pre-trained position encodings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
|
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
|
||||||
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
|
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
@ -1089,7 +1087,7 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
|
|||||||
|
|
||||||
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
Querying Transformer (Q-Former), used in Instructblipvideo. Slightly modified from BLIP-2 as it also takes the
|
Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
|
||||||
instruction as input.
|
instruction as input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -1285,7 +1283,7 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
Instructblipvideo Model for generating text given an image and an optional text prompt. The model consists of a vision
|
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||||
encoder, Querying Transformer (Q-Former) and a language model.
|
encoder, Querying Transformer (Q-Former) and a language model.
|
||||||
|
|
||||||
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
||||||
@ -1358,7 +1356,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
hf_device_map = self.hf_device_map
|
hf_device_map = self.hf_device_map
|
||||||
|
|
||||||
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
||||||
# warn users about unexpected behavior when using multi-GPU + Instructblipvideo + `accelerate`.
|
# warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
||||||
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
||||||
@ -1505,7 +1503,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
)
|
)
|
||||||
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
@ -1584,7 +1581,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
"""
|
r"""
|
||||||
Overrides `generate` function to be able to use the model as a conditional generator.
|
Overrides `generate` function to be able to use the model as a conditional generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -21,32 +21,18 @@ import torch.utils.checkpoint
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from transformers.models.instructblip.configuration_instructblip import (
|
from transformers.models.instructblip.configuration_instructblip import (
|
||||||
InstructBlipConfig,
|
|
||||||
InstructBlipQFormerConfig,
|
InstructBlipQFormerConfig,
|
||||||
InstructBlipVisionConfig,
|
InstructBlipVisionConfig,
|
||||||
)
|
)
|
||||||
from transformers.models.instructblip.modeling_instructblip import (
|
from transformers.models.instructblip.modeling_instructblip import (
|
||||||
InstructBlipAttention,
|
|
||||||
InstructBlipEncoder,
|
|
||||||
InstructBlipEncoderLayer,
|
|
||||||
InstructBlipForConditionalGeneration,
|
InstructBlipForConditionalGeneration,
|
||||||
InstructBlipForConditionalGenerationModelOutput,
|
InstructBlipForConditionalGenerationModelOutput,
|
||||||
InstructBlipMLP,
|
|
||||||
InstructBlipPreTrainedModel,
|
|
||||||
InstructBlipQFormerAttention,
|
|
||||||
InstructBlipQFormerEmbeddings,
|
|
||||||
InstructBlipQFormerEncoder,
|
|
||||||
InstructBlipQFormerIntermediate,
|
|
||||||
InstructBlipQFormerLayer,
|
|
||||||
InstructBlipQFormerModel,
|
|
||||||
InstructBlipQFormerOutput,
|
|
||||||
InstructBlipQFormerSelfOutput,
|
|
||||||
InstructBlipVisionEmbeddings,
|
|
||||||
InstructBlipVisionModel,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation import GenerationMixin
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..auto import CONFIG_MAPPING
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@ -60,8 +46,124 @@ class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoConfig(InstructBlipConfig):
|
class InstructBlipVideoConfig(PretrainedConfig):
|
||||||
pass
|
r"""
|
||||||
|
[`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
|
||||||
|
[`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified
|
||||||
|
arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
|
||||||
|
the defaults will yield a similar configuration to that of the Instructblipvideo
|
||||||
|
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_config (`dict`, *optional*):
|
||||||
|
Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
|
||||||
|
qformer_config (`dict`, *optional*):
|
||||||
|
Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
|
||||||
|
text_config (`dict`, *optional*):
|
||||||
|
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
||||||
|
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||||
|
The number of query tokens passed through the Transformer.
|
||||||
|
|
||||||
|
video_token_index (`int`, *optional*):
|
||||||
|
Token index of special video token.
|
||||||
|
kwargs (*optional*):
|
||||||
|
Dictionary of keyword arguments.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import (
|
||||||
|
... InstructBlipVideoVisionConfig,
|
||||||
|
... InstructBlipVideoQFormerConfig,
|
||||||
|
... OPTConfig,
|
||||||
|
... InstructBlipVideoConfig,
|
||||||
|
... InstructBlipVideoForConditionalGeneration,
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
|
||||||
|
>>> configuration = InstructBlipVideoConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||||
|
>>> model = InstructBlipVideoForConditionalGeneration(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
|
||||||
|
>>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
|
||||||
|
|
||||||
|
>>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
|
||||||
|
>>> vision_config = InstructBlipVideoVisionConfig()
|
||||||
|
>>> qformer_config = InstructBlipVideoQFormerConfig()
|
||||||
|
>>> text_config = OPTConfig()
|
||||||
|
|
||||||
|
>>> config = InstructBlipVideoConfig.from_text_vision_configs(vision_config, qformer_config, text_config)
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "instructblipvideo"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config=None,
|
||||||
|
qformer_config=None,
|
||||||
|
text_config=None,
|
||||||
|
num_query_tokens=32,
|
||||||
|
video_token_index=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if vision_config is None:
|
||||||
|
vision_config = {}
|
||||||
|
logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
|
||||||
|
|
||||||
|
if qformer_config is None:
|
||||||
|
qformer_config = {}
|
||||||
|
logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
|
||||||
|
|
||||||
|
if text_config is None:
|
||||||
|
text_config = {}
|
||||||
|
logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
|
||||||
|
|
||||||
|
self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
|
||||||
|
self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
|
||||||
|
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
|
||||||
|
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
||||||
|
|
||||||
|
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||||
|
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||||
|
|
||||||
|
self.num_query_tokens = num_query_tokens
|
||||||
|
self.video_token_index = video_token_index
|
||||||
|
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||||
|
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
self.initializer_factor = 1.0
|
||||||
|
self.initializer_range = 0.02
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_vision_qformer_text_configs(
|
||||||
|
cls,
|
||||||
|
vision_config: InstructBlipVideoVisionConfig,
|
||||||
|
qformer_config: InstructBlipVideoQFormerConfig,
|
||||||
|
text_config: PretrainedConfig,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
|
||||||
|
language model configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`InstructBlipVideoConfig`]: An instance of a configuration object
|
||||||
|
"""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
vision_config=vision_config.to_dict(),
|
||||||
|
qformer_config=qformer_config.to_dict(),
|
||||||
|
text_config=text_config.to_dict(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -69,67 +171,7 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings):
|
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoAttention(InstructBlipAttention):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoMLP(InstructBlipMLP):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoEncoder(InstructBlipEncoder):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin):
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
@ -146,15 +188,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
||||||
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
|
||||||
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
|
||||||
config.vocab_size]`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
|
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
|
||||||
>>> import torch
|
>>> import torch
|
||||||
@ -339,7 +372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
"""
|
r"""
|
||||||
Overrides `generate` function to be able to use the model as a conditional generator.
|
Overrides `generate` function to be able to use the model as a conditional generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
@ -1,8 +1,8 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_modular_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# diff.py file directly.
|
# modular_xxx.py file directly. One of our CI enforces this
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
@ -20,17 +20,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
from ...utils import (
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from ..auto import CONFIG_MAPPING
|
from ..auto import CONFIG_MAPPING
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextVideoConfig(PretrainedConfig):
|
class LlavaNextVideoConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
|
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_modular_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# diff.py file directly.
|
# modular_xxx.py file directly. One of our CI enforces this
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
@ -19,7 +19,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
@ -130,6 +129,12 @@ def unpad_image(tensor, original_size):
|
|||||||
Returns:
|
Returns:
|
||||||
`torch.Tensor`: The unpadded image tensor.
|
`torch.Tensor`: The unpadded image tensor.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(original_size, (list, tuple)):
|
||||||
|
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
|
||||||
|
raise TypeError(
|
||||||
|
f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
|
||||||
|
)
|
||||||
|
original_size = original_size.tolist()
|
||||||
original_height, original_width = original_size
|
original_height, original_width = original_size
|
||||||
current_height, current_width = tensor.shape[1:]
|
current_height, current_width = tensor.shape[1:]
|
||||||
|
|
||||||
@ -180,6 +185,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
|
|||||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
||||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
|
|
||||||
video_hidden_states (`torch.FloatTensor`, *optional*):
|
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||||
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
@ -191,6 +197,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
|
|||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
@ -455,7 +462,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
|||||||
self.vocab_size = model_embeds.num_embeddings
|
self.vocab_size = model_embeds.num_embeddings
|
||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration._merge_input_ids_with_image_features
|
|
||||||
def _merge_input_ids_with_image_features(
|
def _merge_input_ids_with_image_features(
|
||||||
self,
|
self,
|
||||||
image_features,
|
image_features,
|
||||||
@ -695,7 +701,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
|||||||
|
|
||||||
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
||||||
|
|
||||||
def pack_image_features(self, image_features, image_sizes, image_newline=None):
|
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
|
||||||
"""
|
"""
|
||||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||||
|
|
||||||
@ -704,6 +710,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
|||||||
List of image feature tensor, each contains all the visual feature of all patches.
|
List of image feature tensor, each contains all the visual feature of all patches.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
|
vision_feature_select_strategy (`str`)
|
||||||
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||||
New line embedding vector.
|
New line embedding vector.
|
||||||
Returns:
|
Returns:
|
||||||
@ -718,8 +726,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
|||||||
base_image_feature = image_feature[0]
|
base_image_feature = image_feature[0]
|
||||||
image_feature = image_feature[1:]
|
image_feature = image_feature[1:]
|
||||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||||
if height * width != base_image_feature.shape[0]:
|
|
||||||
|
if vision_feature_select_strategy == "default":
|
||||||
|
expected_num_patches = height * width
|
||||||
|
elif vision_feature_select_strategy == "full":
|
||||||
|
expected_num_patches = height * width + 1
|
||||||
|
if expected_num_patches != base_image_feature.shape[0]:
|
||||||
raise ValueError("The number of patches is not consistent with the image size.")
|
raise ValueError("The number of patches is not consistent with the image size.")
|
||||||
|
|
||||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||||
image_sizes[image_idx],
|
image_sizes[image_idx],
|
||||||
self.config.image_grid_pinpoints,
|
self.config.image_grid_pinpoints,
|
||||||
|
@ -21,15 +21,13 @@ import torch
|
|||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
from transformers.models.llava_next.modeling_llava_next import (
|
from transformers.models.llava_next.modeling_llava_next import (
|
||||||
LlavaNextCausalLMOutputWithPast,
|
LlavaNextCausalLMOutputWithPast,
|
||||||
LlavaNextForConditionalGeneration,
|
LlavaNextForConditionalGeneration,
|
||||||
LlavaNextMultiModalProjector,
|
|
||||||
image_size_to_num_patches,
|
image_size_to_num_patches,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation import GenerationMixin
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
@ -56,18 +54,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
The config object or dictionary of the text backbone.
|
The config object or dictionary of the text backbone.
|
||||||
ignore_index (`int`, *optional*, defaults to -100):
|
ignore_index (`int`, *optional*, defaults to -100):
|
||||||
The ignore index for the loss function.
|
The ignore index for the loss function.
|
||||||
video_token_index (`int`, *optional*, defaults to 32000):
|
|
||||||
The video token index to encode the image prompt.
|
|
||||||
image_token_index (`int`, *optional*, defaults to 32001):
|
image_token_index (`int`, *optional*, defaults to 32001):
|
||||||
The image token index to encode the image prompt.
|
The image token index to encode the image prompt.
|
||||||
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
|
|
||||||
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
|
||||||
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
|
||||||
Stride used in the pooling layer for videos.
|
|
||||||
image_seq_length (`int`, *optional*, defaults to 576):
|
|
||||||
Sequence length of one image embedding.
|
|
||||||
video_seq_length (`int`, *optional*, defaults to 288):
|
|
||||||
Sequence length of one video embedding.
|
|
||||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||||
The activation function used by the multimodal projector.
|
The activation function used by the multimodal projector.
|
||||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||||
@ -81,6 +69,16 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
of the form `(height, width)`.
|
of the form `(height, width)`.
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
Whether the model's input and output word embeddings should be tied.
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
video_token_index (`int`, *optional*, defaults to 32000):
|
||||||
|
The video token index to encode the image prompt.
|
||||||
|
spatial_pool_mode (`str`, *optional*, defaults to `"average"`):
|
||||||
|
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
||||||
|
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
||||||
|
Stride used in the pooling layer for videos.
|
||||||
|
image_seq_length (`int`, *optional*, defaults to 576):
|
||||||
|
Sequence length of one image embedding.
|
||||||
|
video_seq_length (`int`, *optional*, defaults to 288):
|
||||||
|
Sequence length of one video embedding.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -178,7 +176,13 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
|
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
|
||||||
pass
|
"""
|
||||||
|
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
|
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||||
|
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextVideoPooler(nn.Module):
|
class LlavaNextVideoPooler(nn.Module):
|
||||||
@ -215,11 +219,7 @@ class LlavaNextVideoPooler(nn.Module):
|
|||||||
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
|
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin):
|
|
||||||
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
|
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
|
||||||
super().__init__(config, **super_kwargs)
|
super().__init__(config, **super_kwargs)
|
||||||
self.vision_resampler = LlavaNextVideoPooler(config)
|
self.vision_resampler = LlavaNextVideoPooler(config)
|
||||||
@ -287,6 +287,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
num_logits_to_keep: int = 0,
|
||||||
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -298,6 +300,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
num_logits_to_keep (`int`, *optional*):
|
||||||
|
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@ -329,7 +335,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
... frames.append(frame)
|
... frames.append(frame)
|
||||||
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||||
|
|
||||||
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto)
|
>>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto")
|
||||||
>>> processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
|
>>> processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||||
|
|
||||||
>>> prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
|
>>> prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
|
||||||
@ -499,6 +505,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
num_logits_to_keep=num_logits_to_keep,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
@ -529,6 +536,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
past_key_values=outputs.past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
@ -541,6 +550,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
image_sizes=None,
|
image_sizes=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
cache_position=None,
|
cache_position=None,
|
||||||
|
num_logits_to_keep=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
@ -560,6 +570,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration,
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
num_logits_to_keep=num_logits_to_keep,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
@ -182,6 +182,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
|
|||||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
||||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
|
|
||||||
video_hidden_states (`torch.FloatTensor`, *optional*):
|
video_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
|
||||||
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
|
76  utils/check_modular_conversion.py  (new file)
@ -0,0 +1,76 @@
|
|||||||
|
import argparse
|
||||||
|
import difflib
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
# Console for rich printing
|
||||||
|
from modular_model_converter import convert_modular_file
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.syntax import Syntax
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logging.getLogger().setLevel(logging.ERROR)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
|
||||||
|
file_path = modular_file_path.replace("modular_", f"{file_type}_")
|
||||||
|
# Read the actual modeling file
|
||||||
|
with open(file_path, "r") as modeling_file:
|
||||||
|
content = modeling_file.read()
|
||||||
|
output_buffer = StringIO(generated_modeling_content[file_type][0])
|
||||||
|
output_buffer.seek(0)
|
||||||
|
output_content = output_buffer.read()
|
||||||
|
diff = difflib.unified_diff(
|
||||||
|
output_content.splitlines(),
|
||||||
|
content.splitlines(),
|
||||||
|
fromfile=f"{file_path}_generated",
|
||||||
|
tofile=f"{file_path}",
|
||||||
|
lineterm="",
|
||||||
|
)
|
||||||
|
diff_list = list(diff)
|
||||||
|
# Check for differences
|
||||||
|
if diff_list:
|
||||||
|
if fix_and_overwrite:
|
||||||
|
with open(file_path, "w") as modeling_file:
|
||||||
|
modeling_file.write(generated_modeling_content[file_type][0])
|
||||||
|
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
|
||||||
|
else:
|
||||||
|
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
|
||||||
|
diff_text = "\n".join(diff_list)
|
||||||
|
syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
|
||||||
|
console.print(syntax)
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def compare_files(modular_file_path, fix_and_overwrite=False):
|
||||||
|
# Generate the expected modeling content
|
||||||
|
generated_modeling_content = convert_modular_file(modular_file_path)
|
||||||
|
diff = 0
|
||||||
|
for file_type in generated_modeling_content.keys():
|
||||||
|
diff += process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
|
||||||
|
return diff
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--files", default=["all"], type=list, nargs="+", help="List of modular_xxx.py files to compare."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.files == ["all"]:
|
||||||
|
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||||
|
non_matching_files = 0
|
||||||
|
for modular_file_path in args.files:
|
||||||
|
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
||||||
|
|
||||||
|
if non_matching_files and not args.fix_and_overwrite:
|
||||||
|
raise ValueError("Some diff and their modeling code did not match.")
|
69  utils/create_dependency_mapping.py  (new file)
@ -0,0 +1,69 @@
|
|||||||
|
import ast
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
|
||||||
|
# Function to perform topological sorting
|
||||||
|
def topological_sort(dependencies):
|
||||||
|
# Create a graph and in-degree count for each node
|
||||||
|
graph = defaultdict(list)
|
||||||
|
in_degree = defaultdict(int)
|
||||||
|
|
||||||
|
# Build the graph
|
||||||
|
for node, deps in dependencies.items():
|
||||||
|
for dep in deps:
|
||||||
|
graph[dep].append(node) # node depends on dep
|
||||||
|
in_degree[node] += 1 # increase in-degree of node
|
||||||
|
|
||||||
|
# Add all nodes with zero in-degree to the queue
|
||||||
|
zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
|
||||||
|
|
||||||
|
sorted_list = []
|
||||||
|
# Perform topological sorting
|
||||||
|
while zero_in_degree_queue:
|
||||||
|
current = zero_in_degree_queue.popleft()
|
||||||
|
sorted_list.append(current)
|
||||||
|
|
||||||
|
# For each node that current points to, reduce its in-degree
|
||||||
|
for neighbor in graph[current]:
|
||||||
|
in_degree[neighbor] -= 1
|
||||||
|
if in_degree[neighbor] == 0:
|
||||||
|
zero_in_degree_queue.append(neighbor)
|
||||||
|
|
||||||
|
# Handle nodes that have no dependencies and were not initially part of the loop
|
||||||
|
for node in dependencies:
|
||||||
|
if node not in sorted_list:
|
||||||
|
sorted_list.append(node)
|
||||||
|
|
||||||
|
return sorted_list
|
||||||
|
|
||||||
|
|
||||||
|
# Function to extract class and import info from a file
|
||||||
|
def extract_classes_and_imports(file_path):
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
tree = ast.parse(file.read(), filename=file_path)
|
||||||
|
imports = set()
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||||
|
module = node.module if isinstance(node, ast.ImportFrom) else None
|
||||||
|
if module and "transformers" in module:
|
||||||
|
imports.add(module)
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
# Function to map dependencies between classes
|
||||||
|
def map_dependencies(py_files):
|
||||||
|
dependencies = defaultdict(set)
|
||||||
|
# First pass: Extract all classes and map to files
|
||||||
|
for file_path in py_files:
|
||||||
|
dependencies[file_path].add(None)
|
||||||
|
class_to_file = extract_classes_and_imports(file_path)
|
||||||
|
for module in class_to_file:
|
||||||
|
dependencies[file_path].add(module)
|
||||||
|
return dependencies
|
||||||
|
|
||||||
|
|
||||||
|
def find_priority_list(py_files):
|
||||||
|
dependencies = map_dependencies(py_files)
|
||||||
|
ordered_classes = topological_sort(dependencies)
|
||||||
|
return ordered_classes[::-1]
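To make the ordering concrete, here is a small illustrative example of the helpers above (file names are hypothetical, and it assumes the script is importable, e.g. when run from `utils/`):

```python
from create_dependency_mapping import topological_sort

# Hypothetical case: modular_gemma2.py is built on top of modular_gemma.py.
deps = {
    "modular_gemma2.py": {"modular_gemma.py"},
    "modular_gemma.py": set(),
}

# Kahn's algorithm above emits dependencies before their dependents:
# ['modular_gemma.py', 'modular_gemma2.py']
# find_priority_list returns the reverse of this order.
print(topological_sort(deps))
```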
|
@@ -20,21 +20,23 @@ from typing import Dict

 import libcst as cst
 from check_copies import run_ruff
+from create_dependency_mapping import find_priority_list
 from libcst import ClassDef, CSTTransformer, CSTVisitor
 from libcst import matchers as m
 from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider

 from transformers import logging
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES


 logger = logging.get_logger(__name__)


 AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from <path_to_diff_file.py>.
+# This file was automatically generated from <path_to_modular_file.py>.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 """

@@ -82,12 +84,16 @@ class ClassFinder(CSTVisitor):
         self.function_def = {}  # stores global scope function definition
         self.assignments = {}  # LLAMA_DOCSTRING
         self.class_dependency_mapping = {}  # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
+        self.first_lvl_dependency_mapping = {}  # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
         # fmt: on

     def _update_class_dependency(self, name, value):
         """Update the dependency mapping for `name` with `value` by appending the previous
         dependencies to the new `value`.
         """
+        dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value})
+        self.first_lvl_dependency_mapping[name] = dep
+
         dep = set(self.class_dependency_mapping.get(value, set()))
         dep |= set(self.class_dependency_mapping.get(name, {})) | set({value})
         self.class_dependency_mapping[name] = dep
@@ -146,7 +152,16 @@ class ClassFinder(CSTVisitor):
     def leave_Decorator(self, node):
         if hasattr(node.decorator, "args"):
             for k in node.decorator.args:
-                if k.value.value in self.assignments:
+                if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))):  # and k.value.func.value.value:
+                    if k.value.func.value.value not in self.assignments:
+                        raise ValueError(
+                            f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}"
+                        )
+                    parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
+                    scope = self.get_metadata(cst.metadata.ScopeProvider, node)
+                    name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
+                    self._update_class_dependency(name, k.value.func.value.value)
+                elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments:
                     parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
                     scope = self.get_metadata(cst.metadata.ScopeProvider, node)
                     name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
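For readers unfamiliar with `libcst.matchers`, a small sketch of the pattern the new branch above recognizes (the decorator argument below is illustrative):

```python
import libcst as cst
from libcst import matchers as m

# The new branch fires when a decorator argument is itself a call on a named
# assignment, e.g. `@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING.format(...))`.
arg_value = cst.parse_expression('LLAMA_INPUTS_DOCSTRING.format("batch_size, sequence_length")')
print(m.matches(arg_value, m.Call(func=m.Attribute(value=m.Name()))))  # True
print(arg_value.func.value.value)  # "LLAMA_INPUTS_DOCSTRING", the name checked against self.assignments
```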
@@ -178,6 +193,10 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
         self.old_name = old_name
         self.new_name = new_name
         self.default_name = "".join(x.title() for x in new_name.split("_"))
+        if self.new_name in CONFIG_MAPPING_NAMES:
+            self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
+                "Config", ""
+            )  # the best source of truth for class names. Could also just use the ones de
         self.patterns = {
             old_name: new_name,
             old_name.upper(): new_name.upper(),
@@ -193,7 +212,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):

         def replace(match):
             word = match.group(0)
-            return self.patterns.get(word, self.default_name)
+            result = self.patterns.get(word, self.default_name)
+            return result

         return compiled_regex.sub(replace, text)

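A standalone sketch of the renaming mechanism this transformer relies on (plain `re`; the names below are illustrative):

```python
import re

# Every casing of the old model name listed in `patterns` is swapped in one regex
# pass; a match not present in the dict falls back to the default (title-case) name.
patterns = {"llama": "gemma", "Llama": "Gemma", "LLAMA": "GEMMA"}
compiled_regex = re.compile("|".join(patterns.keys()))

def replace(match):
    word = match.group(0)
    result = patterns.get(word, "Gemma")
    return result

print(compiled_regex.sub(replace, "LlamaModel reuses LLAMA_INPUTS_DOCSTRING"))
# -> GemmaModel reuses GEMMA_INPUTS_DOCSTRING
```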
@@ -227,35 +247,102 @@ DOCSTRING_NODE = m.SimpleStatementLine(
 )


+def SUPER_CALL_NODE(func_name):
+    return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
+
+
+def get_docstring_indent(docstring):
+    # Match the first line after the opening triple quotes
+    match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
+    if match:
+        # Return the indentation spaces captured
+        return len(match.group(1))
+    return 0
+
+
+def merge_docstrings(original_docstring, updated_docstring):
+    # indent_level = get_docstring_indent(updated_docstring)
+    original_level = get_docstring_indent(original_docstring)
+    if " Args:\n " not in updated_docstring:
+        # Split the docstring at the example section, assuming `"""` is used to define the docstring
+        parts = original_docstring.split("```")
+        if "```" in updated_docstring and len(parts) > 1:
+            updated_docstring = updated_docstring.lstrip('r"')
+            new_parts = updated_docstring.split("```")
+            if len(new_parts) != 3:
+                raise ValueError("There should only be one example, and it should have opening and closing '```'")
+            parts[1] = new_parts[1]
+            updated_docstring = "".join(
+                [
+                    parts[0].rstrip(" \n") + new_parts[0],
+                    f"\n{original_level*' '}```",
+                    parts[1],
+                    "```",
+                    parts[2],
+                ]
+            )
+        elif updated_docstring not in original_docstring:
+            # add tabulation if we are at the lowest level.
+            if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
+                updated_docstring = updated_docstring.replace("\n ", "\n ")
+            updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
+    return updated_docstring
+
+
 class SuperTransformer(cst.CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider,)

-    def __init__(self, python_module: cst.Module, original_methods, updated_methods):
+    def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
         self.python_module = python_module
         self.original_methods = original_methods
         self.updated_methods = updated_methods
+        self.all_assign_target = {}
+        self.deleted_targets = {}  # child node can delete some arguments
+        self.class_name = class_name

     def update_body(self, existing_body, new_statements):
         """
         Helper method to update the body by removing duplicates before adding new statements.
+        `existing_body` is the body of the original method, the parent class
+        `new_statements` are the additional statements
         """
         deduplicated_new_body = []
         existing_nodes = set()
+        for node in new_statements:
+            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
+                target = self.python_module.code_for_node(node.body[0].targets[0].target)
+                self.all_assign_target[target] = node
+            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
+                target = self.python_module.code_for_node(node.body[0].target)
+                self.deleted_targets[target] = node
+                continue
+
+        for stmt in existing_body:
+            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
+                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
+                if target in self.deleted_targets:
+                    logger.warning(f"Deleted the assign for {target}")
+                    continue
+                if target in self.all_assign_target:
+                    stmt = self.all_assign_target[target]
+            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            deduplicated_new_body.append(stmt)
+            existing_nodes.add(comment_less_code)
+
         for node in new_statements:
             code = self.python_module.code_for_node(node)
             comment_less_code = re.sub(r"#.*", "", code).strip()
             comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            if (
+                node not in deduplicated_new_body
+                and "super().__init__" not in comment_less_code
+                and comment_less_code not in existing_nodes
+            ):
+                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
+                    # HACK here to fix the pos_init() that has to be last we kinda do this.
+                    deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:]
             existing_nodes.add(comment_less_code)
-        for stmt in existing_body:
-            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
-            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
-            if comment_less_code not in existing_nodes:
-                if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
-                    continue
-                deduplicated_new_body.append(stmt)
-                existing_nodes.add(stmt)
-            else:
-                logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}")
         return deduplicated_new_body

     def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
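A quick sketch of what the new `SUPER_CALL_NODE` matcher recognizes (the input statement is illustrative):

```python
import libcst as cst
from libcst import matchers as m

def SUPER_CALL_NODE(func_name):
    # same helper as above: matches `super().<func_name>(...)`
    return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))

stmt = cst.parse_module("super().forward(hidden_states)").body[0]  # a SimpleStatementLine
print(m.matches(stmt, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE("forward")) | m.Expr(SUPER_CALL_NODE("forward"))])))
# True -> replace_super_calls() splices the parent's `forward` body in at this point
```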
@@ -263,26 +350,37 @@ class SuperTransformer(cst.CSTTransformer):
         to super().func_name() with the source code of the parent class' `func_name`.
         It keeps everything that is defined before `super().func_name()`.
         """
-        new_body = []
         self.has_docstring = False
-        for expr in node.body:
-            self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE)
+        parent_has_docstring = False
+        if func_name in self.original_methods:
+            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
+        new_body = []
+        has_super_call = False
+        for idx, expr in enumerate(node.body):
             if m.matches(
                 expr,
                 m.SimpleStatementLine(
-                    body=[
-                        m.Return(
-                            value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
-                        )
-                        | m.Expr(
-                            value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
-                        )
-                    ]
+                    body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
                 ),
             ):
+                if idx != 0 and func_name == "__init__":
+                    raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
                 new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
+                has_super_call = True
+            elif m.matches(expr, DOCSTRING_NODE):
+                self.has_docstring = True
+                if parent_has_docstring:  # actually here we ought to de-duplicate?
+                    original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
+                    updated_docstring = expr.body[0].value.value
+                    merged_doc = merge_docstrings(original_docstring, updated_docstring)
+                    new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                 else:
+                    new_node = [expr]
+                new_body.extend(new_node)
+            elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                 new_body.append(expr)
+        if not self.has_docstring and parent_has_docstring:
+            new_body = [self.original_methods[func_name].body.body[0]] + new_body
         return node.with_changes(body=new_body)

     def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
@@ -330,14 +428,22 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
     | ```
     """
     original_node = class_finder.classes[class_name]
-    original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
-    updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
+    original_methods = {
+        f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
+        for f in original_node.body.body
+    }
+    updated_methods = {
+        f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
+        for f in updated_node.body.body
+    }
     end_meth = []

+    assign_targets = {}
+    docstring_node = []
     # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
     for func in original_node.body.body:
-        name = func.name.value if hasattr(func, "name") else func
-        if name in updated_methods and updated_methods[name] is not None:
+        name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
+        if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
             new_params = updated_methods[name].params
             # Replace the method in the replacement class, preserving decorators
             kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
@@ -348,22 +454,61 @@
                 params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
             )
             func = func.with_changes(body=updated_methods[name].body, params=new_params)
+        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
+            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
+            assign_targets[target] = func
+        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
+            target = class_finder.python_module.code_for_node(func.body[0].target)
+            assign_targets[target] = func
+        elif m.matches(func, DOCSTRING_NODE):
+            docstring_node = [func]
+        else:
             end_meth.append(func)

-    # Port new methods that are defined only in diff-file and append at the end
-    for name, func in updated_methods.items():
+    # Port new methods that are defined only in modular-file and append at the end
+    for func in updated_node.body.body:
+        name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
+        if m.matches(func, DOCSTRING_NODE):  # This processes the docstring of the class!
+            # Extract the original docstring
+            updated_docstring = func.body[0].value.value
+            original_docstring = docstring_node[0].body[0].value.value
+            merged_doc = merge_docstrings(original_docstring, updated_docstring)
+            # Update the docstring in the original function
+            docstring_node = [
+                docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
+            ]
         if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
             end_meth.append(func)
+        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
+            # TODO we only use single assign might cause issues
+            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
+            assign_targets[target] = func
+        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
+            target = class_finder.python_module.code_for_node(func.body[0].target)
+            assign_targets[target] = func
+    end_meth = docstring_node + list(assign_targets.values()) + end_meth

     result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
     temp_module = cst.Module(body=[result_node])
     new_module = MetadataWrapper(temp_module)
-    new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
+    new_replacement_class = new_module.visit(
+        SuperTransformer(temp_module, original_methods, updated_methods, class_name)
+    )
     new_replacement_body = new_replacement_class.body[0].body  # get the indented block

     return original_node.with_changes(body=new_replacement_body)


-class DiffConverterTransformer(CSTTransformer):
+TYPE_TO_FILE_TYPE = {
+    "Config": "configuration",
+    "Tokenizer": "tokenization",
+    "Processor": "processor",
+    "ImageProcessor": "image_processing",
+    "FeatureExtractor": "feature_extractor",
+}
+
+
+class ModularConverterTransformer(CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

     def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
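The new `TYPE_TO_FILE_TYPE` table decides which generated file a class ends up in. A small sketch of that routing (the `target_file` helper below is illustrative, not part of the converter):

```python
import re

# A class is routed by its name suffix; anything without a known suffix falls back
# to the "modeling" file, mirroring the leave_ClassDef logic further below.
TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processor",
    "ImageProcessor": "image_processing",
    "FeatureExtractor": "feature_extractor",
}
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())

def target_file(class_name: str) -> str:
    match = re.search(rf"({match_pattern})$", class_name)
    return TYPE_TO_FILE_TYPE[match.group(1)] if match else "modeling"

print(target_file("MyModelConfig"))       # configuration
print(target_file("MyModelForCausalLM"))  # modeling
```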
@@ -378,11 +523,21 @@ class DiffConverterTransformer(CSTTransformer):
         self.transformers_imports = {}  # maps the imports name like "from transformers.models.xxx" to the parsed AST module
         self.imported_mapping = {}  # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
         self.visited_module = {}  # modules visited like "transformers.models.llama.modeling_llama"
-        self.new_body = {}  # store the new body, all global scope nodes should be added here
         self.inserted_deps = []  # nodes inserted via super dependency
         self.all_imports = []  # just stores all of the imports
+        self.all_safe_imports = []  # stores the import under simple statements
         self.global_scope_index = 0
         # fmt: on
+        self.files = {  # mapping for different component bodies
+            "modeling": {},
+            "configuration": {},
+            "tokenization": {},
+            "processing": {},
+            "image_processing": {},
+            "feature_extractor": {},
+        }
+        self.match_patterns = "|".join(self.files.keys())
+        self.all_functions = {}

     def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
         """When visiting imports from `transformers.models.xxx` we need to:
@@ -393,7 +548,7 @@ class DiffConverterTransformer(CSTTransformer):
         import_statement = self.python_module.code_for_node(node.module)
         if m.matches(node.module, m.Attribute()):
             for imported_ in node.names:
-                _import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
+                _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement)
                 if _import:
                     source = _import.groups()[0]
                     if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
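The widened regex above now accepts both absolute and relative per-model imports, for any of the component prefixes registered in `self.files`. A standalone sketch of what it matches:

```python
import re

# Imports that point at a per-model component file (modeling_*, configuration_*, ...)
# are resolved or stripped by the converter; anything else is kept as a plain import.
match_patterns = "modeling|configuration|tokenization|processing|image_processing|feature_extractor"
for stmt in (
    "transformers.models.llama.modeling_llama",
    "..llama.configuration_llama",
    "torch.nn",
):
    print(stmt, "->", bool(re.search(rf"(transformers\.models\..|..)*\.({match_patterns})_.*", stmt)))
# -> True, True, False
```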
@@ -401,44 +556,38 @@ class DiffConverterTransformer(CSTTransformer):
                             f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                         )
                     if import_statement not in self.transformers_imports:
+                        if "models" not in import_statement:
+                            import_statement = "models." + import_statement
+                        if "transformers" not in import_statement:
+                            import_statement = "transformers." + import_statement
                         source_code = get_module_source_from_name(import_statement)
                         tree = cst.parse_module(source_code)
                         self.transformers_imports[import_statement] = tree
                     imported_class = self.python_module.code_for_node(imported_.name)
                     self.imported_mapping[imported_class] = import_statement
-
-    def leave_FunctionDef(self, original_node, node):
-        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
-        if m.matches(parent_node, m.Module()):
-            self.global_scope_index += 100
-            self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
-        return node
+        if m.matches(node.module, m.Name()):
+            if "transformers" == import_statement:
+                raise ValueError(
+                    f"You are importing from {import_statement} directly using global imports. Import from the correct local path"
+                )

     def leave_SimpleStatementLine(self, original_node, updated_node):
         parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
         if m.matches(parent_node, m.Module()):
             if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
-                if parent_node not in self.all_imports:
+                if updated_node not in self.all_imports:
                     self.all_imports.append(updated_node)
                 return updated_node
             elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                 full_statement = self.python_module.code_for_node(updated_node.body[0].module)
-                if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
+                if re.search(
+                    rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement
+                ):  # OR MATCH ..llama.modeling_llama
                     return cst.RemoveFromParent()
-                if parent_node not in self.all_imports:
+                if updated_node not in self.all_imports:
                     self.all_imports.append(updated_node)
                 return updated_node
             self.global_scope_index += 100
-            if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
-                # TODO This only works for single target assigns!
-                node_name = updated_node.body[0].targets[0].target.value
-            else:
-                node_name = self.python_module.code_for_node(updated_node.body[0])
-            self.new_body[node_name] = {
-                "insert_idx": self.global_scope_index,
-                "node": updated_node,
-            }
-            self.config_body = [updated_node]
         return updated_node

     def leave_ClassDef(self, original_node, updated_node):
@@ -454,6 +603,7 @@ class DiffConverterTransformer(CSTTransformer):
         """
         class_name = original_node.name.value
         bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
+        all_bases = [k.value.value for k in original_node.bases]
         self.global_scope_index += 100
         for super_class in bases:
             if super_class not in self.imported_mapping:
@@ -469,7 +619,7 @@ class DiffConverterTransformer(CSTTransformer):
                raise ValueError(
                    f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
                )
+            file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
            visited_module = self.visited_module
            if super_file_name not in visited_module:  # only extract classes once
                class_finder = find_classes_in_file(
@@ -490,22 +640,47 @@ class DiffConverterTransformer(CSTTransformer):

             list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
             start_insert_idx = self.global_scope_index
+            file_to_update = self.files[file_type]
+            is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n"
             for dependency, _ in list_dependencies:
+                # we can write to the correct body, using the source of the parent class
                 node = class_finder.global_nodes.get(dependency, None)
-                if node is not None and "Config" not in class_name:
-                    if dependency not in self.new_body:
+                if node is not None:
+                    if dependency not in file_to_update:
                         start_insert_idx -= 1
-                        self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
+                        file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
                     elif dependency not in self.inserted_deps:
                         # make sure the node is written after its dependencies
-                        start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
+                        start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
+                        if (
+                            dependency in file_to_update.keys()
+                            and dependency in class_finder.first_lvl_dependency_mapping[class_name]
+                        ):
+                            # If dependency is defined, but not used, raise error
+                            calls = m.findall(original_node, m.Call(func=m.Name(dependency)))
+                            if not calls and not is_empty_node and dependency not in all_bases:
+                                raise ValueError(
+                                    f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used
+                                    when you define `{class_name}`, as it is one of it's direct dependencies. Make sure
+                                    you use it in the `__init__` function."""
+                                )
                     self.inserted_deps.append(dependency)

             if len(list_dependencies) > 0:
                 updated_node = replace_call_to_super(class_finder, updated_node, class_name)
-            if "Config" in class_name:
-                self.config_body += [updated_node]
             else:
-                self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
+                raise ValueError(
+                    f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
+                )
+
+        # Now, if a class was defined without parents, we look for the name
+        match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
+        match = re.search(rf"({match_pattern})$", class_name)
+        if match:
+            key = TYPE_TO_FILE_TYPE[match.group(1)]
+            self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
+        else:
+            self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
         return updated_node

     def leave_If(self, original_node, node):
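The new "defined but unused dependency" guard relies on `m.findall`; a small illustration (the class and helper names below are made up):

```python
import libcst as cst
from libcst import matchers as m

# Look for any call to a given name inside a class body, as the check above does.
klass = cst.parse_module(
    "class MyModelAttention(nn.Module):\n"
    "    def __init__(self):\n"
    "        self.norm = MyModelRMSNorm()\n"
).body[0]
print(len(m.findall(klass, m.Call(func=m.Name("MyModelRMSNorm")))) > 0)  # True
```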
@@ -513,66 +688,69 @@ class DiffConverterTransformer(CSTTransformer):
         if m.matches(parent_node, m.Module()):
             full_statement = self.python_module.code_for_node(original_node.test)
             if re.search(r"[\s\S]*is_.*available", full_statement):
-                self.all_imports.append(node)
+                self.all_safe_imports.append(node)
             elif full_statement not in self.new_body:
                 self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
         return node

     def leave_Module(self, original_node: cst.Assign, node):
         imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
-        dependency_imports = {}
-        config_imports = []
-        for visiter in self.visited_module.values():
-            dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
+        dependency_imports = {file_type: imports.copy() for file_type in self.files}
+        for super_file_name, visiter in self.visited_module.items():
+            file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
+            dependency_imports[file_type].update(
+                {self.python_module.code_for_node(k): k for k in visiter.imports.values()}
+            )

-        # manually clean up if it's importing a config from configuration file (ruff doesn't do that)
-        config_imports = []
-        for i in list(dependency_imports.values()):
-            if (
-                hasattr(i.body[0], "module")
-                and isinstance(i.body[0].module, cst.Name)
-                and f"configuration_{self.model_name}" in i.body[0].module.value
-            ):
-                pass
-            else:
-                config_imports.append(i)
-
-        if hasattr(self, "config_body"):
-            self.config_body = list(imports.values()) + config_imports + self.config_body
-        dependency_imports.update(imports)
-        new_body = list(dependency_imports.values())
-        if len(self.new_body.keys()) > 0:
-            new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
-        else:
-            new_body = []
-        return node.with_changes(body=[*new_body])
+        for file, body in self.files.items():
+            new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
+            if len(new_body) > 0:
+                if file in dependency_imports.keys():
+                    new_body = list(dependency_imports[file].values()) + new_body
+                self.files[file] = cst.Module(body=[*new_body], header=node.header)
+        return node


-def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
-    model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
+def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
+    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
+    output = {}
+    if pattern is not None:
+        model_name = pattern.groups()[0]
         # Parse the Python file
-    with open(diff_file, "r") as file:
+        with open(modular_file, "r") as file:
             code = file.read()
         module = cst.parse_module(code)
         wrapper = MetadataWrapper(module)
         if cst_transformers is None:
-        cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
-    new_mod = wrapper.visit(cst_transformers)
-    ruffed_code = run_ruff(new_mod.code, True)
+            cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name)
+        wrapper.visit(cst_transformers)
+        for file, node in cst_transformers.files.items():
+            if node != {}:
+                ruffed_code = run_ruff(AUTO_GENERATED_MESSAGE + node.code, True)
                 formatted_code = run_ruff(ruffed_code, False)
-    if len(formatted_code.strip()) > 0:
-        with open(diff_file.replace("diff_", "modeling_"), "w") as f:
-            f.write(AUTO_GENERATED_MESSAGE + formatted_code)
+                output[file] = [formatted_code, ruffed_code]
+        return output
+    else:
+        print(f"modular pattern not found in {modular_file}, exiting")
+        return {}

-    if hasattr(cst_transformers, "config_body"):
-        config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
-        with open(diff_file.replace("diff_", "configuration_"), "w") as f:
-            ruffed_code = run_ruff(config_module.code, True)
-            formatted_code = run_ruff(ruffed_code, False)
-            f.write(AUTO_GENERATED_MESSAGE + formatted_code)

-    # TODO optimize by re-using the class_finder
-    return cst_transformers
+def save_modeling_file(modular_file, converted_file):
+    for file_type in converted_file.keys():
+        non_comment_lines = len(
+            [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
+        )
+        if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
+            with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
+                f.write(converted_file[file_type][0])
+        else:
+            non_comment_lines = len(
+                [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
+            )
+            if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
+                logger.warning("The modeling code contains errors, it's written without formatting")
+                with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
+                    f.write(converted_file[file_type][1])
+
+
 if __name__ == "__main__":
@@ -581,22 +759,24 @@ if __name__ == "__main__":
         "--files_to_parse",
         default=["all"],
         nargs="+",
-        help="A list of `diff_xxxx` files that should be converted to single model file",
+        help="A list of `modular_xxxx` files that should be converted to single model file",
     )
     parser.add_argument(
         "--old_model_name",
         required=False,
-        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
+        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
     )
     parser.add_argument(
         "--new_model_name",
        required=False,
-        help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
+        help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
     )
     args = parser.parse_args()
     if args.files_to_parse == ["all"]:
-        args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
-    for file_name in args.files_to_parse:
+        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
+    for file_name in find_priority_list(args.files_to_parse):
         print(f"Converting {file_name} to a single model single file format")
         module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
-        converter = convert_file(file_name, args.old_model_name, args.new_model_name)
+        converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
+        converter = save_modeling_file(file_name, converted_files)
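A hypothetical single-file driver mirroring the `__main__` block above (a sketch only: it assumes `convert_modular_file` and `save_modeling_file` from this converter script are importable, and the path is illustrative):

```python
# Convert one modular file and write the generated component files next to it.
file_name = "src/transformers/models/gemma/modular_gemma.py"  # illustrative path
converted_files = convert_modular_file(file_name)
save_modeling_file(file_name, converted_files)  # e.g. writes modeling_gemma.py / configuration_gemma.py
```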