Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Granite speech speedup + model saving bugfix (#39028)
* ensure the query is updated during training
avoid unused parameters that DDP does not like
* avoid a crash when `kwargs` contain `padding=True`
trainers often pass this argument automatically
* minor
* Remove mel_spec lazy init, and rename to mel_filters.
This ensures save_pretrained will not crash when saving the processor during training (a short illustration of the failure mode follows this list).
d5d007a1a0/src/transformers/feature_extraction_utils.py (L595)
* minor - most feature extractors have a `sampling_rate` property
* speedup relative position embeddings
* fix several issues in model saving/loading:
- avoid modifying `self._hf_peft_config_loaded` when saving
- the adapter_config automatically points to the original base model, but a finetuned version should point to the model save dir.
- fixing model weight names, which are changed by adding an adapter.
* minor
* minor
* minor
* fixing a crash without peft active
* add todo to replace einsum
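For context on the `mel_spec` → `mel_filters` change: `save_pretrained` serializes the feature extractor's attributes to `preprocessor_config.json`, so a lazily created torch module stored on the instance breaks saving, while the `mel_filters` name is the one the linked line in `feature_extraction_utils.py` handles during serialization. A minimal sketch of the failure mode, using a hypothetical `DummyFeatureExtractor` rather than the real Granite speech feature extractor:

```python
import json

# Hypothetical illustration (not the real Granite speech feature extractor): attributes left
# on the instance must survive JSON encoding when save_pretrained writes preprocessor_config.json.
class DummyFeatureExtractor:
    def __init__(self):
        self.sampling_rate = 16000
        self.feature_size = 128
        self.mel_spec = object()  # stand-in for a lazily created torchaudio MelSpectrogram

fe = DummyFeatureExtractor()
try:
    json.dumps(fe.__dict__)
except TypeError as err:
    print(f"saving would crash: {err}")  # Object of type object is not JSON serializable
```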
This commit is contained in: parent 1d45d90e5d, commit 22b0a89878
@@ -159,8 +159,12 @@ class GraniteSpeechConformerAttention(nn.Module):
         # shaw's relative positional embedding
         dist = attention_dists.to(hidden_states.device)
         rel_pos_emb = self.rel_pos_emb(dist)
-        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
-        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # alternative computation of `pos_attn` - for readability
+        # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+        # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # einsum implementation of pos_attn - gives x30 speedup over the alternative
+        # TODO (@avihu111) find a fast alternative to einsum
+        pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale

         if remainder > 0:
             # masked attention in the extended block
@@ -541,17 +545,34 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
             self.disable_adapters()
         return super().generate(*args, input_features=input_features, **kwargs)

-    def save_pretrained(self, *args, **kwargs):
+    def save_pretrained(self, save_directory, *args, **kwargs):
         # overwrite save_pretrained to first save the adapter if we have one
         # NOTE - this will use the base model path we are exporting in the lora
         # adapter, which may not necessarily be the best behavior, but for now
         # we keep this for portability, since using the local dir causes problems
         # if the model is loaded from outside of the current working dir.
         if is_peft_available and self._hf_peft_config_loaded:
-            super().save_pretrained(*args, **kwargs)
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)
         # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
         self._hf_peft_config_loaded = False
-        super().save_pretrained(*args, **kwargs)
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }
+
+    def _get_adapter_name(self):
+        return list(self.peft_config.keys())[0]


 __all__ = [
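To make the new state-dict handling above concrete, the sketch below applies the same renaming rule to a made-up state dict (the keys and placeholder values are illustrative, not real Granite speech weights): `.base_layer` infixes injected by the adapter are stripped so the base weights keep their original names, and `.lora_` tensors are excluded from the base-model save.

```python
# Illustrative only: placeholder values instead of real tensors, made-up key names.
def fix_state_dict_key_on_save(key):
    # same rule as the new _fix_state_dict_key_on_save: strip the ".base_layer" infix
    return key.replace(".base_layer", ""), False

state_dict = {
    "language_model.layers.0.self_attn.q_proj.base_layer.weight": 0,
    "language_model.layers.0.self_attn.q_proj.lora_A.default.weight": 0,
    "language_model.layers.0.self_attn.q_proj.lora_B.default.weight": 0,
}

cleaned = {
    fix_state_dict_key_on_save(key)[0]: value
    for key, value in state_dict.items()
    if ".lora_" not in key
}
print(list(cleaned))  # ['language_model.layers.0.self_attn.q_proj.weight']
```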