diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py
index d30254ca62a..6e61f732b77 100644
--- a/src/transformers/models/granite_speech/modeling_granite_speech.py
+++ b/src/transformers/models/granite_speech/modeling_granite_speech.py
@@ -159,8 +159,12 @@ class GraniteSpeechConformerAttention(nn.Module):
         # shaw's relative positional embedding
         dist = attention_dists.to(hidden_states.device)
         rel_pos_emb = self.rel_pos_emb(dist)
-        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
-        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # alternative computation of `pos_attn` - for readability
+        # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+        # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # einsum implementation of pos_attn - gives x30 speedup over the alternative
+        # TODO (@avihu111) find a fast alternative to einsum
+        pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale

         if remainder > 0:
             # masked attention in the extended block
@@ -541,17 +545,34 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
             self.disable_adapters()
         return super().generate(*args, input_features=input_features, **kwargs)

-    def save_pretrained(self, *args, **kwargs):
+    def save_pretrained(self, save_directory, *args, **kwargs):
         # overwrite save_pretrained to first save the adapter if we have one
-        # NOTE - this will use the base model path we are exporting in the lora
-        # adapter, which may not necessarily be the best behavior, but for now
-        # we keep this for portability, since using the local dir causes problems
-        # if the model is loaded from outside of the current working dir.
         if is_peft_available and self._hf_peft_config_loaded:
-            super().save_pretrained(*args, **kwargs)
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)
         # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
         self._hf_peft_config_loaded = False
-        super().save_pretrained(*args, **kwargs)
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }
+
+    def _get_adapter_name(self):
+        return list(self.peft_config.keys())[0]


 __all__ = [
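
Two short sketches for reviewers follow; neither is part of the patch.

First, a minimal equivalence check for the `pos_attn` change, assuming shapes inferred from the einsum subscripts (`query_states` as `(batch, num_blocks, heads, block_len, head_dim)` and `rel_pos_emb` as `(block_len, num_rel_positions, head_dim)`); the commented-out broadcast-and-sum form and the new einsum form should agree numerically:

```python
import torch

# toy dimensions; names follow the einsum subscripts "b m h c d, c r d -> b m h c r"
b, m, h, c, r, d = 2, 3, 4, 6, 11, 8
scale = d**-0.5

query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)

# original formulation: broadcast to (b, m, h, c, r, d), then reduce over head_dim
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn_ref = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

# einsum formulation: contracts head_dim without materializing the 6-D intermediate
pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

assert torch.allclose(pos_attn, pos_attn_ref, atol=1e-5)
```

Second, a toy illustration of the key handling in `_fix_state_dict_keys_on_save` when saving the base model: PEFT-wrapped layers carry `.base_layer` and `.lora_*` segments in their parameter names, so the LoRA tensors are filtered out and `.base_layer` is stripped to restore the original key layout. The key strings below are hypothetical:

```python
def fix_key(key: str) -> tuple[str, bool]:
    # mirrors `_fix_state_dict_key_on_save` from the patch
    return key.replace(".base_layer", ""), False

state_dict = {
    "language_model.layers.0.self_attn.q_proj.base_layer.weight": "W_q",
    "language_model.layers.0.self_attn.q_proj.lora_A.default.weight": "A",
    "language_model.layers.0.self_attn.q_proj.lora_B.default.weight": "B",
}

base_only = {fix_key(k)[0]: v for k, v in state_dict.items() if ".lora_" not in k}
print(base_only)  # {'language_model.layers.0.self_attn.q_proj.weight': 'W_q'}
```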