Add gradient checkpointing to Whisper Flax (#22954)

* Add gradient checkpointing to Whisper Flax

* self.gradient_checkpointing only needed in nn.Module, removing unnecessary comments
This commit is contained in:
Javier de la Rosa 2023-04-26 18:19:16 +02:00 committed by GitHub
parent a72b82ebe6
commit ba0dc54576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
@ -53,6 +54,8 @@ logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"
remat = nn_partitioning.remat
WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -387,16 +390,23 @@ class FlaxWhisperEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
class FlaxWhisperEncoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
self.layers = [
FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
else:
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
def __call__(
@ -531,16 +541,23 @@ class FlaxWhisperDecoderLayer(nn.Module):
return outputs
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
class FlaxWhisperDecoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
self.layers = [
FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
else:
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
def __call__(
@ -570,12 +587,12 @@ class FlaxWhisperDecoderLayerCollection(nn.Module):
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
output_attentions,
deterministic,
)
hidden_states = layer_outputs[0]
@ -605,6 +622,7 @@ class FlaxWhisperDecoderLayerCollection(nn.Module):
class FlaxWhisperEncoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self) -> None:
self.conv1 = nn.Conv(
@ -628,6 +646,7 @@ class FlaxWhisperEncoder(nn.Module):
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
@ -689,12 +708,15 @@ class FlaxWhisperEncoder(nn.Module):
class FlaxWhisperDecoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self) -> None:
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@ -753,10 +775,15 @@ class FlaxWhisperDecoder(nn.Module):
class FlaxWhisperModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.decoder = FlaxWhisperDecoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
def __call__(
self,
@ -821,11 +848,19 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
@ -1137,9 +1172,12 @@ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqM
class FlaxWhisperForConditionalGenerationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self) -> None:
self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
self.model = FlaxWhisperModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,