Granite speech speedup + model saving bugfix (#39028)

* ensure the query is updated during training

avoid unused parameters that DDP does not like
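For context (an illustrative toy module, not the Granite code): when a submodule's parameters never contribute to the loss they receive no gradients, which is exactly what DDP's reducer complains about unless it is run with `find_unused_parameters=True`.

```python
# Toy example: `proj_b` is never used in forward(), so its parameters get no
# gradient after backward() - the situation that trips up DDP.
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_a = nn.Linear(4, 4)
        self.proj_b = nn.Linear(4, 4)  # defined but unused in forward()

    def forward(self, x):
        return self.proj_a(x)


model = Toy()
model(torch.randn(2, 4)).sum().backward()
print([name for name, p in model.named_parameters() if p.grad is None])
# ['proj_b.weight', 'proj_b.bias']
```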

* avoid a crash when `kwargs` contain `padding=True`

trainers often pass this argument automatically
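The pattern below only illustrates the failure mode and one way to tolerate it (the helper and its names are made up; this is not the exact change in this PR): Trainer-style code tends to forward `padding=True` down to the processor, and a callee that does not accept that keyword raises a TypeError.

```python
# Hypothetical feature-extraction helper that tolerates an injected
# `padding=True` instead of raising TypeError on an unexpected keyword.
import numpy as np


def extract_features(audio: np.ndarray, sampling_rate: int = 16000, **kwargs):
    kwargs.pop("padding", None)  # accepted but ignored here
    if kwargs:
        raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
    return {"input_features": audio[None, :], "sampling_rate": sampling_rate}


features = extract_features(np.zeros(16000, dtype=np.float32), padding=True)  # no crash
```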

* minor

* Remove mel_spec lazy init, and rename it to mel_filters.
This ensures save_pretrained will not crash when saving the processor during training:
d5d007a1a0/src/transformers/feature_extraction_utils.py (L595)
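The reasoning, paraphrased from memory of the linked helper (treat the attribute names and code as an approximation, not verbatim source): the feature-extractor serialization path drops well-known signal-processing buffers such as `mel_filters` and `window` before writing JSON, so storing the filter bank under that conventional name keeps it out of the saved config, whereas a custom `mel_spec` attribute would end up in the dict and break serialization.

```python
# Approximate sketch of the serialization step the link refers to; the real
# implementation lives in transformers.feature_extraction_utils.
import copy


def to_dict(feature_extractor) -> dict:
    output = copy.deepcopy(feature_extractor.__dict__)
    output["feature_extractor_type"] = feature_extractor.__class__.__name__
    # non-JSON-serializable buffers are excluded by name before the JSON dump
    for key in ("mel_filters", "window"):
        output.pop(key, None)
    return output
```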

* minor - most feature extractors have a `sampling_rate` property

* speedup relative position embeddings

* fix several issues in model saving/loading:
- avoid modifying `self._hf_peft_config_loaded` when saving
- the adapter_config automatically points to the original base model; a finetuned version should instead point to the model save dir (see the round-trip sketch after this list)
- fix model weight names, which are changed by adding an adapter
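A minimal round-trip sketch of the intended behavior (the hub id and local path below are placeholders for illustration): after saving, the adapter config written next to the weights names the save directory itself as its base model, so the fine-tuned checkpoint can be reloaded from that directory without resolving back to the original base model id.

```python
# Illustrative round trip; model id and output path are placeholders.
from transformers import GraniteSpeechForConditionalGeneration

model = GraniteSpeechForConditionalGeneration.from_pretrained("ibm-granite/granite-speech-3.3-8b")
# ... fine-tune the LoRA adapter ...
model.save_pretrained("./granite-speech-finetuned")

# The saved adapter config now points at "./granite-speech-finetuned",
# and the exported base weights keep their original (non-PEFT) names.
reloaded = GraniteSpeechForConditionalGeneration.from_pretrained("./granite-speech-finetuned")
```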

* minor

* minor

* minor

* fix a crash when peft is not active

* add todo to replace einsum
Author: Avihu Dekel (committed by GitHub), 2025-06-26 10:44:17 +03:00
Parent: 1d45d90e5d
Commit: 22b0a89878

@@ -159,8 +159,12 @@ class GraniteSpeechConformerAttention(nn.Module):
         # shaw's relative positional embedding
         dist = attention_dists.to(hidden_states.device)
         rel_pos_emb = self.rel_pos_emb(dist)
-        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
-        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # alternative computation of `pos_attn` - for readability
+        # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+        # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # einsum implementation of pos_attn - gives x30 speedup over the alternative
+        # TODO (@avihu111) find a fast alternative to einsum
+        pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale
         if remainder > 0:
             # masked attention in the extended block
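A self-contained sanity check (tensor shapes below are made up) that the einsum contraction introduced in this hunk matches the broadcast-multiply-and-sum formulation it replaces:

```python
# Verify "b m h c d, c r d -> b m h c r" against the original computation.
import torch

b, m, h, c, d, r = 2, 3, 4, 5, 8, 7
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)

old = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb.view(1, 1, 1, c, r, d), dim=-1)
new = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb)
torch.testing.assert_close(old, new)
```

The claimed x30 speedup can be measured the same way by timing both lines on realistic shapes; the commented-out code kept in the model preserves the readable formulation for exactly that purpose.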
@@ -541,17 +545,34 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
             self.disable_adapters()
         return super().generate(*args, input_features=input_features, **kwargs)

-    def save_pretrained(self, *args, **kwargs):
+    def save_pretrained(self, save_directory, *args, **kwargs):
         # overwrite save_pretrained to first save the adapter if we have one
         # NOTE - this will use the base model path we are exporting in the lora
         # adapter, which may not necessarily be the best behavior, but for now
         # we keep this for portability, since using the local dir causes problems
         # if the model is loaded from outside of the current working dir.
         if is_peft_available and self._hf_peft_config_loaded:
-            super().save_pretrained(*args, **kwargs)
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)

         # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
         self._hf_peft_config_loaded = False
-        super().save_pretrained(*args, **kwargs)
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }
+
+    def _get_adapter_name(self):
+        return list(self.peft_config.keys())[0]
+
 __all__ = [
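To illustrate what the two `_fix_state_dict_key*` hooks above do (parameter names below are made up): with a LoRA adapter attached, PEFT wraps the targeted linear layers, so the base weights live under `...base_layer.weight` and extra `lora_A`/`lora_B` tensors appear. When exporting the base model, the hooks strip the wrapper suffix and drop the LoRA tensors so the checkpoint keeps the original weight names.

```python
# Hypothetical PEFT-wrapped parameter names and the exported result.
def fix_key(key: str) -> str:
    return key.replace(".base_layer", "")


state_dict = {
    "language_model.layers.0.q_proj.base_layer.weight": "W",
    "language_model.layers.0.q_proj.lora_A.default.weight": "A",
    "language_model.layers.0.q_proj.lora_B.default.weight": "B",
}
exported = {fix_key(k): v for k, v in state_dict.items() if ".lora_" not in k}
print(exported)  # {'language_model.layers.0.q_proj.weight': 'W'}
```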