Granite speech - minor fixes to support training with the HF trainer (#38833)

* ensure the query is updated during training

avoid unused parameters that DDP does not like

* avoid a crash when `kwargs` contain `padding=True`

trainers often pass this argument automatically

* minor

* Remove mel_spec lazy init, and rename to mel_filters.
This ensures save_pretrained will not crash when saving the processor during training
d5d007a1a0/src/transformers/feature_extraction_utils.py (L595)

* minor - most feature extractors have a `sampling_rate` property
This commit is contained in:
Avihu Dekel 2025-06-24 18:06:52 +03:00 committed by GitHub
parent e1e11b0299
commit be10d4df60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 24 deletions

View File

@ -50,6 +50,7 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
**kwargs,
):
super().__init__(**kwargs)
self.sampling_rate = sampling_rate
self.melspec_kwargs = {
"sample_rate": sampling_rate,
"n_fft": n_fft,
@ -57,8 +58,8 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
"hop_length": hop_length,
"n_mels": n_mels,
}
# Currently lazily initialized
self.melspec = None
requires_backends(self, ["torchaudio"])
self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
self.projector_window_size = projector_window_size
self.projector_downsample_rate = projector_downsample_rate
@ -91,34 +92,16 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
).view(-1, 1)
return BatchFeature(data=speech_inputs)
def _ensure_melspec_transform_is_initialized(self):
"""
Ensures the mel spectrogram transform on this instance is initialized.
We do this for now since some logging explodes since the mel spectrogram
transform is not JSON serializable.
"""
requires_backends(self, ["torchaudio"])
if self.melspec is None:
# TODO (@alex-jw-brooks / @eustlb) move this to common batch
# feature extraction in audio utils once they are written!
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
"""
Compute the Mel features to be passed to the conformer encoder.
"""
requires_backends(self, ["torchaudio"])
# Initialize the mel spectrogram if isn't not already and
# move the melspec / audio to the computation device.
self._ensure_melspec_transform_is_initialized()
if device is not None:
melspec = self.melspec.to(device)
melspec = self.mel_filters.to(device)
audio = audio.to(device)
else:
melspec = self.melspec
melspec = self.mel_filters
bsz = audio.shape[0]
with torch.no_grad():

View File

@ -83,7 +83,7 @@ class GraniteSpeechEncoderProjector(nn.Module):
hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
query_output = self.qformer(
query_embeds=self.query.data,
query_embeds=self.query,
encoder_hidden_states=hidden_states,
encoder_attention_mask=None,
return_dict=True,

View File

@ -88,7 +88,9 @@ class GraniteSpeechProcessor(ProcessorMixin):
else:
audio_inputs = {}
text_inputs = self.tokenizer(prompt_strings, padding=True, **kwargs)
if "padding" not in kwargs:
kwargs["padding"] = True
text_inputs = self.tokenizer(prompt_strings, **kwargs)
return BatchFeature(data={**text_inputs, **audio_inputs})
def _get_validated_text(self, text: Union[str, list]) -> list[str]: