mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
PegasusX add _no_split_modules (#25933)
* no_split_modules
* no_split_modules
* inputs_embeds+pos same device
* update _no_split_modules
* update _no_split_modules
This commit is contained in:
parent 70a98024b1
commit da1af21dbb
@@ -769,6 +769,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
     config_class = PegasusXConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
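For context, `_no_split_modules` is the attribute Accelerate's big-model inference consults when it infers a `device_map`: modules listed there are kept whole on a single device rather than being split internally across devices. A minimal usage sketch, assuming a PegasusX checkpoint such as google/pegasus-x-base and whatever accelerators happen to be available (the checkpoint name and device setup are illustrative, not part of this commit):

import torch
from transformers import AutoTokenizer, PegasusXForConditionalGeneration

checkpoint = "google/pegasus-x-base"  # assumed checkpoint; any PegasusX model works

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# With _no_split_modules defined, the inferred device_map shards the model
# across available GPUs/CPU without ever splitting a single
# PegasusXEncoderLayer or PegasusXDecoderLayer across two devices.
model = PegasusXForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)

text = "PEGASUS-X extends PEGASUS with staggered block-local attention for long inputs."
inputs = tokenizer(text, return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

Once the model is sharded this way, intermediate tensors can end up on different devices, which is what the second hunk below addresses.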
@@ -1299,6 +1300,8 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         # embed positions
         positions = self.embed_positions(inputs_embeds, past_key_values_length)
 
+        positions = positions.to(inputs_embeds.device)
+
         hidden_states = inputs_embeds + positions
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
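The device-placement change matters precisely in that sharded setting: `self.embed_positions` may have been placed on a different device than the module that produced `inputs_embeds`, and adding the two tensors directly would raise a device-mismatch RuntimeError. A self-contained sketch of the failure mode and the fix, with devices and shapes chosen purely for illustration:

import torch

# Emulate a sharded model: the position embeddings may sit on a different
# device than the token embeddings. With a single device the mismatch never
# surfaces, so this example falls back to CPU and still runs.
device_embeds = "cuda:0" if torch.cuda.is_available() else "cpu"
device_positions = "cuda:1" if torch.cuda.device_count() > 1 else device_embeds

inputs_embeds = torch.randn(1, 8, 16, device=device_embeds)
positions = torch.randn(1, 8, 16, device=device_positions)

# Without the fix, `inputs_embeds + positions` raises a RuntimeError whenever
# the two tensors live on different devices.
positions = positions.to(inputs_embeds.device)  # the commit's approach
hidden_states = inputs_embeds + positions
print(hidden_states.device)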