Mirror of https://github.com/huggingface/transformers.git
Add gradient checkpointing to Whisper Flax (#22954)
* Add gradient checkpointing to Whisper Flax

* self.gradient_checkpointing only needed in nn.Module, removing unnecessary comments
This commit is contained in:
parent a72b82ebe6
commit ba0dc54576
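
Not part of the commit, just context: a minimal usage sketch of the option this diff adds, assuming the openai/whisper-tiny doc checkpoint referenced in this file and dummy inputs. enable_gradient_checkpointing(), introduced below on FlaxWhisperPreTrainedModel, rebuilds the Flax module with gradient_checkpointing=True so the encoder/decoder layers are wrapped in remat; the same flag can also be passed to the model constructor.

import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration

# Load the doc checkpoint and switch the module over to checkpointed layers.
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.enable_gradient_checkpointing()

# Dummy inputs: Whisper expects log-mel features of shape (batch, num_mel_bins, frames) = (1, 80, 3000).
input_features = jnp.zeros((1, 80, 3000), dtype=jnp.float32)
decoder_input_ids = jnp.full((1, 1), model.config.decoder_start_token_id, dtype=jnp.int32)

# Forward pass; the memory savings appear when differentiating through this call.
outputs = model(input_features, decoder_input_ids=decoder_input_ids)
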
@@ -23,6 +23,7 @@ import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
@@ -53,6 +54,8 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+remat = nn_partitioning.remat
+
 
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -387,16 +390,23 @@ class FlaxWhisperEncoderLayer(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
 class FlaxWhisperEncoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.encoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
+            self.layers = [
+                FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
         self.layerdrop = self.config.encoder_layerdrop
 
     def __call__(
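
Context, not part of the diff: flax.linen.partitioning.remat wraps a module so its intermediate activations are recomputed during the backward pass instead of being stored, trading compute for memory. static_argnums lists the positional __call__ arguments (counted after self) that are plain Python values, here the output_attentions/deterministic-style flags, so they can still drive Python control flow under the checkpoint transform; this is also why the decoder-layer call later in this diff switches from keyword to positional arguments. An illustrative sketch of the pattern with a made-up ToyLayer module:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyLayer(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        # Python-level branch: `deterministic` must be static under remat,
        # otherwise it would be traced as an array and the `if` would fail.
        if not deterministic:
            hidden_states = nn.Dropout(rate=0.1)(hidden_states, deterministic=False)
        return hidden_states


# Indices count __call__ arguments after `self`: hidden_states -> 0, deterministic -> 1.
ToyCheckpointLayer = remat(ToyLayer, static_argnums=(1,))

layer = ToyCheckpointLayer(features=16)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 8)), True)
out = layer.apply(params, jnp.ones((2, 8)), True)
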
@@ -531,16 +541,23 @@ class FlaxWhisperDecoderLayer(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
 class FlaxWhisperDecoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.decoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
+            self.layers = [
+                FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
         self.layerdrop = self.config.decoder_layerdrop
 
     def __call__(
@@ -570,12 +587,12 @@ class FlaxWhisperDecoderLayerCollection(nn.Module):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    init_cache=init_cache,
-                    output_attentions=output_attentions,
-                    deterministic=deterministic,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    init_cache,
+                    output_attentions,
+                    deterministic,
                 )
 
             hidden_states = layer_outputs[0]
@@ -605,6 +622,7 @@ class FlaxWhisperDecoderLayerCollection(nn.Module):
 class FlaxWhisperEncoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.conv1 = nn.Conv(
@@ -628,6 +646,7 @@ class FlaxWhisperEncoder(nn.Module):
         self.layers = FlaxWhisperEncoderLayerCollection(
             self.config,
             dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
         )
         self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
 
@@ -689,12 +708,15 @@
 class FlaxWhisperDecoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
         self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
 
-        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
+        self.layers = FlaxWhisperDecoderLayerCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
 
@@ -753,10 +775,15 @@
 class FlaxWhisperModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
-        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxWhisperEncoder(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
+        self.decoder = FlaxWhisperDecoder(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
     def __call__(
         self,
@@ -821,11 +848,19 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs,
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_features = jnp.zeros(input_shape, dtype="f4")
@@ -1137,9 +1172,12 @@ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqM
 class FlaxWhisperForConditionalGenerationModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
+        self.model = FlaxWhisperModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.lm_head = nn.Dense(
             self.config.vocab_size,
             use_bias=False,