mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
FlaxGPT2 (#11556)
* flax gpt2 * combine masks * handle shared embeds * add causal LM sample * style * add tests * style * fix imports, docs, quality * don't use cache * add cache * add cache 1st version * make use cache work * start adding test for generation * finish generation loop compilation * rewrite test * finish * update * update * apply sylvains suggestions * update * refactor * fix typo Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
eb3e072a3b
commit
ca33278fdb
@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
@ -205,6 +205,13 @@ FlaxAutoModel
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForCausalLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForCausalLM
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -139,3 +139,17 @@ TFSequenceClassifierOutputWithPast
|
||||
|
||||
.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
|
||||
:members:
|
||||
|
||||
|
||||
FlaxGPT2Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxGPT2Model
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxGPT2LMHeadModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxGPT2LMHeadModel
|
||||
:members: __call__
|
||||
|
@ -1409,6 +1409,7 @@ if is_flax_available():
|
||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -1418,6 +1419,7 @@ if is_flax_available():
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
"FlaxAutoModel",
|
||||
"FlaxAutoModelForCausalLM",
|
||||
"FlaxAutoModelForMaskedLM",
|
||||
"FlaxAutoModelForMultipleChoice",
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
@ -1452,6 +1454,7 @@ if is_flax_available():
|
||||
"FlaxElectraPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
|
||||
_import_structure["models.roberta"].extend(
|
||||
[
|
||||
"FlaxRobertaForMaskedLM",
|
||||
@ -2634,6 +2637,7 @@ if TYPE_CHECKING:
|
||||
if is_flax_available():
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
from .models.auto import (
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -2643,6 +2647,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModel,
|
||||
FlaxAutoModelForCausalLM,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
FlaxAutoModelForMultipleChoice,
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
@ -2672,6 +2677,7 @@ if TYPE_CHECKING:
|
||||
FlaxElectraModel,
|
||||
FlaxElectraPreTrainedModel,
|
||||
)
|
||||
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
||||
from .models.roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
|
@ -1038,6 +1038,20 @@ FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_CAUSAL_LM_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
||||
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
||||
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_SAMPLE_DOCSTRINGS = {
|
||||
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
|
||||
@ -1045,6 +1059,7 @@ FLAX_SAMPLE_DOCSTRINGS = {
|
||||
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
|
||||
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
|
||||
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
|
||||
"LMHead": FLAX_CAUSAL_LM_SAMPLE,
|
||||
}
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
|
||||
@ -46,6 +46,36 @@ class FlaxBaseModelOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxBaseModelOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`):
|
||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: jax_xla.DeviceArray = None
|
||||
past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||
"""
|
||||
@ -103,6 +133,9 @@ class FlaxMaskedLMOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||
"""
|
||||
|
@ -85,6 +85,7 @@ if is_tf_available():
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_auto"] = [
|
||||
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -94,6 +95,7 @@ if is_flax_available():
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
"FlaxAutoModel",
|
||||
"FlaxAutoModelForCausalLM",
|
||||
"FlaxAutoModelForMaskedLM",
|
||||
"FlaxAutoModelForMultipleChoice",
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
@ -167,6 +169,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_auto import (
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -176,6 +179,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModel,
|
||||
FlaxAutoModelForCausalLM,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
FlaxAutoModelForMultipleChoice,
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
|
@ -37,6 +37,7 @@ from ..electra.modeling_flax_electra import (
|
||||
FlaxElectraForTokenClassification,
|
||||
FlaxElectraModel,
|
||||
)
|
||||
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
||||
from ..roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
@ -46,7 +47,7 @@ from ..roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
from .auto_factory import auto_class_factory
|
||||
from .configuration_auto import BertConfig, ElectraConfig, RobertaConfig
|
||||
from .configuration_auto import BertConfig, ElectraConfig, GPT2Config, RobertaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -57,6 +58,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
||||
# Base model mapping
|
||||
(RobertaConfig, FlaxRobertaModel),
|
||||
(BertConfig, FlaxBertModel),
|
||||
(GPT2Config, FlaxGPT2Model),
|
||||
(ElectraConfig, FlaxElectraModel),
|
||||
]
|
||||
)
|
||||
@ -79,6 +81,13 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
(GPT2Config, FlaxGPT2LMHeadModel)
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
@ -123,6 +132,10 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
|
||||
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
||||
|
||||
FlaxAutoModelForCausalLM = auto_class_factory(
|
||||
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
|
||||
)
|
||||
|
||||
FlaxAutoModelForPreTraining = auto_class_factory(
|
||||
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||
)
|
||||
|
@ -18,7 +18,13 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
|
||||
from ...file_utils import (
|
||||
_BaseLazyModule,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -51,6 +57,8 @@ if is_tf_available():
|
||||
"TFGPT2PreTrainedModel",
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
@ -81,6 +89,9 @@ if TYPE_CHECKING:
|
||||
TFGPT2PreTrainedModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
|
633
src/transformers/models/gpt2/modeling_flax_gpt2.py
Normal file
633
src/transformers/models/gpt2/modeling_flax_gpt2.py
Normal file
@ -0,0 +1,633 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.linen import combine_masks, dot_product_attention, make_causal_mask
|
||||
from flax.traverse_util import flatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
||||
from ...utils import logging
|
||||
from .configuration_gpt2 import GPT2Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "gpt2"
|
||||
_CONFIG_FOR_DOC = "GPT2Config"
|
||||
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
||||
|
||||
|
||||
GPT2_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||
generic methods the library implements for all its model (such as downloading or saving, resizing the input
|
||||
embeddings, pruning heads etc.)
|
||||
|
||||
This model is also a Flax Linen `flax.nn.Module
|
||||
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
|
||||
Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
|
||||
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
|
||||
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
|
||||
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
|
||||
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
"""
|
||||
|
||||
GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
|
||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxConv1D(nn.Module):
|
||||
features: int
|
||||
use_bias: bool = True
|
||||
dtype: Any = jnp.float32
|
||||
precision: Any = None
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
inputs = jnp.asarray(inputs, self.dtype)
|
||||
kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
|
||||
kernel = jnp.asarray(kernel.transpose(), self.dtype)
|
||||
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
|
||||
if self.use_bias:
|
||||
bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
|
||||
bias = jnp.asarray(bias, self.dtype)
|
||||
y = y + bias
|
||||
return y
|
||||
|
||||
|
||||
class FlaxGPT2Attention(nn.Module):
|
||||
config: GPT2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
config = self.config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
|
||||
self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
|
||||
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
|
||||
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
||||
|
||||
@nn.compact
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
deterministic: bool = True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
qkv_out = self.c_attn(hidden_states)
|
||||
query, key, value = jnp.split(qkv_out, 3, axis=2)
|
||||
|
||||
query = self._split_heads(query)
|
||||
key = self._split_heads(key)
|
||||
value = self._split_heads(value)
|
||||
|
||||
query_length, key_length = query.shape[1], key.shape[1]
|
||||
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
|
||||
dropout_rng = None
|
||||
if not deterministic and self.config.attn_pdrop > 0.0:
|
||||
dropout_rng = self.make_rng("dropout")
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.has_variable("cache", "cached_key") or init_cache:
|
||||
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
|
||||
|
||||
# transform boolean mask into float mask
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
|
||||
)
|
||||
|
||||
# usual dot product attention
|
||||
attn_output = dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
bias=attention_bias,
|
||||
dropout_rng=dropout_rng,
|
||||
dropout_rate=self.config.attn_pdrop,
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype,
|
||||
precision=None,
|
||||
)
|
||||
|
||||
attn_output = self._merge_heads(attn_output)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
|
||||
|
||||
# TODO: at the moment it's not possible to retrieve attn_weights from
|
||||
# dot_product_attention, but should be in the future -> add functionality then
|
||||
|
||||
return (attn_output,)
|
||||
|
||||
|
||||
class FlaxGPT2MLP(nn.Module):
|
||||
config: GPT2Config
|
||||
intermediate_size: int
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
embed_dim = self.config.hidden_size
|
||||
self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
|
||||
self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype)
|
||||
self.act = ACT2FN[self.config.activation_function]
|
||||
self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
|
||||
|
||||
def __call__(self, hidden_states, deterministic: bool = True):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxGPT2Block(nn.Module):
|
||||
config: GPT2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
hidden_size = self.config.hidden_size
|
||||
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
||||
self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
|
||||
self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
||||
self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
deterministic: bool = True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
outputs = self.attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
deterministic=deterministic,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
# residual connection
|
||||
attn_output = outputs[0]
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
return (hidden_states,) + outputs[1:]
|
||||
|
||||
|
||||
class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = GPT2Config
|
||||
base_model_prefix = "transformer"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@property
|
||||
def _attn_layer_name(self):
|
||||
attn_layer_key_tuple = ("h", "0", "attn")
|
||||
if self.base_model_prefix in set(self.params.keys()):
|
||||
attn_layer_key_tuple = (self.base_model_prefix,) + attn_layer_key_tuple
|
||||
return attn_layer_key_tuple
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (:obj:`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (:obj:`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length))
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return init_variables["cache"]
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
past_key_values: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
|
||||
if position_ids is None:
|
||||
if past_key_values is not None and input_ids.shape[-1] == 1:
|
||||
# if `past_key_values` are passed and input_ids are longer than 1, we are in cached auto-regressive generation. It has to be made sure that position_ids are set correctly
|
||||
cache_shift = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cache_index",)]
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(self.config.max_position_embeddings)[None, :],
|
||||
(batch_size, self.config.max_position_embeddings),
|
||||
)
|
||||
position_ids = lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
||||
|
||||
if attention_mask is None:
|
||||
# if past_key_values are passed we need to create an attention_mask of the same length as `cache_length`
|
||||
if past_key_values is not None:
|
||||
cache_length = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cached_key",)].shape[
|
||||
1
|
||||
]
|
||||
else:
|
||||
cache_length = sequence_length
|
||||
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyways. Thus we can create a single static attention_mask here, which is more efficient for compilation
|
||||
attention_mask = jnp.ones((batch_size, cache_length))
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
False,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxGPT2BlockCollection(nn.Module):
|
||||
config: GPT2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.blocks = [
|
||||
FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
deterministic: bool = True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
for block in self.blocks:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = block(hidden_states, attention_mask, deterministic=deterministic, init_cache=init_cache)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
class FlaxGPT2Module(nn.Module):
|
||||
config: GPT2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
self.wte = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wpe = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
|
||||
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
|
||||
self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
deterministic=True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
input_embeds = self.wte(input_ids.astype("i4"))
|
||||
position_embeds = self.wpe(position_ids.astype("i4"))
|
||||
|
||||
hidden_states = input_embeds + position_embeds
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
|
||||
outputs = self.h(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic=deterministic,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
GPT2_START_DOCSTRING,
|
||||
)
|
||||
class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
|
||||
module_class = FlaxGPT2Module
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxGPT2LMHeadModule(nn.Module):
|
||||
config: GPT2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype)
|
||||
self.lm_head = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
init_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
|
||||
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
""",
|
||||
GPT2_START_DOCSTRING,
|
||||
)
|
||||
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
||||
module_class = FlaxGPT2LMHeadModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
|
||||
)
|
@ -11,6 +11,9 @@ class FlaxPreTrainedModel:
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
@ -44,6 +47,15 @@ class FlaxAutoModel:
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
@ -248,6 +260,24 @@ class FlaxElectraPreTrainedModel:
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxGPT2LMHeadModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxGPT2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
@ -247,12 +247,8 @@ class FlaxModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
).to_tuple()
|
||||
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
@ -266,11 +262,11 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# jitted function cannot return OrderedDict
|
||||
|
332
tests/test_modeling_flax_gpt2.py
Normal file
332
tests/test_modeling_flax_gpt2.py
Normal file
@ -0,0 +1,332 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import GPT2Config, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FlaxGPT2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = None
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
self.pad_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = GPT2Config(
|
||||
vocab_size=self.vocab_size,
|
||||
n_embd=self.hidden_size,
|
||||
n_layer=self.num_hidden_layers,
|
||||
n_head=self.num_attention_heads,
|
||||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings,
|
||||
use_cache=False,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
return (config, input_ids, input_mask)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
|
||||
max_decoder_length = 20
|
||||
model = model_class_name(config)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
|
||||
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
|
||||
|
||||
outputs = model(input_ids)
|
||||
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
|
||||
max_decoder_length = 20
|
||||
model = model_class_name(config)
|
||||
|
||||
attention_mask_cache = jnp.concatenate(
|
||||
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
|
||||
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
|
||||
)
|
||||
|
||||
outputs = model(input_ids, attention_mask=attention_mask)
|
||||
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_generation(self, config, input_ids):
|
||||
prompt_length = 3
|
||||
model = FlaxGPT2LMHeadModel(config)
|
||||
max_length = 10
|
||||
batch_size = 1
|
||||
|
||||
prompt_ids = input_ids[:1, :prompt_length]
|
||||
|
||||
# put all generation logic into one function
|
||||
def generate(prompt_ids):
|
||||
def first_pass(prompt_ids):
|
||||
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
|
||||
next_token = jnp.argmax(logits[:, -1:], axis=-1)
|
||||
return next_token, cache
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
cur_len, _, _, _ = state
|
||||
return ~(cur_len == max_length - 1)
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
cur_len, sequences, current_token, cache = state
|
||||
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
|
||||
|
||||
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
|
||||
next_token = jnp.argmax(next_logits, axis=-1)
|
||||
|
||||
return cur_len + 1, next_sequences, next_token, next_cache
|
||||
|
||||
# init tensor to be filled with generation result
|
||||
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
|
||||
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
|
||||
|
||||
# init past key values for cache
|
||||
past_key_values = model.init_cache(batch_size, max_length)
|
||||
|
||||
# first pass with long prompt
|
||||
next_token, cache = first_pass(prompt_ids)
|
||||
|
||||
# prepare state for generation loop
|
||||
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
|
||||
|
||||
# fast generation
|
||||
_, output_sequences, final_token, _ = lax.while_loop(
|
||||
greedy_search_cond_fn, greedy_search_body_fn, init_state
|
||||
)
|
||||
|
||||
# append last token
|
||||
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
|
||||
|
||||
return output_sequences
|
||||
|
||||
jit_generate = jax.jit(generate)
|
||||
output_sequences = jit_generate(prompt_ids)
|
||||
self.parent.assertEqual(output_sequences.shape, (1, max_length))
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxGPT2ModelTester(self)
|
||||
|
||||
def test_use_cache_forward(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
|
||||
|
||||
def test_use_cache_forward_with_attn_mask(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_forward_with_attn_mask(
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
def test_use_cache_generation(self):
|
||||
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_generation(config, input_ids)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
pt_model = pt_model_class(config).eval()
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("gpt2", from_pt=True)
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
Loading…
Reference in New Issue
Block a user