mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Granite speech - minor fixes to support training with the HF trainer (#38833)
* ensure the query is updated during training
avoid unused parameters that DDP does not like
* avoid a crash when `kwargs` contain `padding=True`
trainers often pass this argument automatically
* minor
* Remove mel_spec lazy init, and rename to mel_filters.
this ensures save_pretrained will not crash when saving the processor during training
d5d007a1a0/src/transformers/feature_extraction_utils.py (L595)
* minor - most feature extractors has a `sampling_rate` property
This commit is contained in:
parent
e1e11b0299
commit
be10d4df60
@ -50,6 +50,7 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.sampling_rate = sampling_rate
|
||||
self.melspec_kwargs = {
|
||||
"sample_rate": sampling_rate,
|
||||
"n_fft": n_fft,
|
||||
@ -57,8 +58,8 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
|
||||
"hop_length": hop_length,
|
||||
"n_mels": n_mels,
|
||||
}
|
||||
# Currently lazily initialized
|
||||
self.melspec = None
|
||||
requires_backends(self, ["torchaudio"])
|
||||
self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
|
||||
self.projector_window_size = projector_window_size
|
||||
self.projector_downsample_rate = projector_downsample_rate
|
||||
|
||||
@ -91,34 +92,16 @@ class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
|
||||
).view(-1, 1)
|
||||
return BatchFeature(data=speech_inputs)
|
||||
|
||||
def _ensure_melspec_transform_is_initialized(self):
|
||||
"""
|
||||
Ensures the mel spectrogram transform on this instance is initialized.
|
||||
|
||||
We do this for now since some logging explodes since the mel spectrogram
|
||||
transform is not JSON serializable.
|
||||
"""
|
||||
requires_backends(self, ["torchaudio"])
|
||||
|
||||
if self.melspec is None:
|
||||
# TODO (@alex-jw-brooks / @eustlb) move this to common batch
|
||||
# feature extraction in audio utils once they are written!
|
||||
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
|
||||
|
||||
def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
|
||||
"""
|
||||
Compute the Mel features to be passed to the conformer encoder.
|
||||
"""
|
||||
requires_backends(self, ["torchaudio"])
|
||||
|
||||
# Initialize the mel spectrogram if isn't not already and
|
||||
# move the melspec / audio to the computation device.
|
||||
self._ensure_melspec_transform_is_initialized()
|
||||
if device is not None:
|
||||
melspec = self.melspec.to(device)
|
||||
melspec = self.mel_filters.to(device)
|
||||
audio = audio.to(device)
|
||||
else:
|
||||
melspec = self.melspec
|
||||
melspec = self.mel_filters
|
||||
|
||||
bsz = audio.shape[0]
|
||||
with torch.no_grad():
|
||||
|
@ -83,7 +83,7 @@ class GraniteSpeechEncoderProjector(nn.Module):
|
||||
hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
|
||||
|
||||
query_output = self.qformer(
|
||||
query_embeds=self.query.data,
|
||||
query_embeds=self.query,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
return_dict=True,
|
||||
|
@ -88,7 +88,9 @@ class GraniteSpeechProcessor(ProcessorMixin):
|
||||
else:
|
||||
audio_inputs = {}
|
||||
|
||||
text_inputs = self.tokenizer(prompt_strings, padding=True, **kwargs)
|
||||
if "padding" not in kwargs:
|
||||
kwargs["padding"] = True
|
||||
text_inputs = self.tokenizer(prompt_strings, **kwargs)
|
||||
return BatchFeature(data={**text_inputs, **audio_inputs})
|
||||
|
||||
def _get_validated_text(self, text: Union[str, list]) -> list[str]:
|
||||
|
Loading…
Reference in New Issue
Block a user