mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add gradient_checkpointing parameter to FlaxWhisperEncoder (#23300)
Add gradient_checkpointing parameter
This commit is contained in:
parent
83eda6435e
commit
ab96bf0294
@ -1515,7 +1515,9 @@ class FlaxWhisperForAudioClassificationModule(nn.Module):
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self) -> None:
|
||||
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxWhisperEncoder(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.config.is_encoder_decoder = False
|
||||
num_layers = self.config.num_hidden_layers + 1
|
||||
if self.config.use_weighted_layer_sum:
|
||||
|
Loading…
Reference in New Issue
Block a user