Add gradient_checkpointing parameter to FlaxWhisperEncoder (#23300)

Add gradient_checkpointing parameter
This commit is contained in:
raghavanone 2023-05-11 23:43:05 +05:30 committed by GitHub
parent 83eda6435e
commit ab96bf0294
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1515,7 +1515,9 @@ class FlaxWhisperForAudioClassificationModule(nn.Module):
gradient_checkpointing: bool = False
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum: