mirror of https://github.com/huggingface/transformers.git
[OPT/Galactica] Load large galactica models (#20390)

* fix `opt` bias
* revert unneeded assignment
commit b75255cd9d
parent 293991d44b
@@ -74,6 +74,10 @@ class OPTConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        enable_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not the linear layers in the attention blocks should use the bias term.
+        layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether or not the layer norms should have learnable parameters.

     Example:

@@ -112,6 +116,8 @@ class OPTConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=2,
         eos_token_id=2,
+        enable_bias=True,
+        layer_norm_elementwise_affine=True,
         **kwargs
     ):
         super().__init__(
@@ -134,6 +140,9 @@ class OPTConfig(PretrainedConfig):
         self.layerdrop = layerdrop
         self.use_cache = use_cache
         self.do_layer_norm_before = do_layer_norm_before
+        # We keep these variables at `True` for backward compatibility.
+        self.enable_bias = enable_bias
+        self.layer_norm_elementwise_affine = layer_norm_elementwise_affine

         # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
         # with checkpoints that have been fine-tuned before transformers v4.20.1

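Taken together, the configuration now exposes the two switches Galactica-style checkpoints need: no bias on the linear layers and layer norms without learnable affine parameters. A minimal usage sketch through `OPTConfig`; the tiny hyperparameters below are illustrative only, not taken from any released checkpoint:

from transformers import OPTConfig, OPTModel

# Tiny, illustrative hyperparameters; only the two new flags matter here.
config = OPTConfig(
    hidden_size=64,
    ffn_dim=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    enable_bias=False,                    # linear layers in the blocks get no bias term
    layer_norm_elementwise_affine=False,  # layer norms get no learnable weight/bias
)
model = OPTModel(config)

layer = model.decoder.layers[0]
assert layer.self_attn.k_proj.bias is None        # bias disabled on the attention projections
assert layer.self_attn_layer_norm.weight is None  # no affine parameters on the layer norm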
@@ -279,15 +279,18 @@ class OPTDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
             is_decoder=True,
+            bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]

-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
-        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = nn.LayerNorm(
+            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
+        )
+        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
+        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

     def forward(
         self,
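For reference, this is what the two keyword arguments change at the plain PyTorch level; a small standalone sketch, nothing OPT-specific:

import torch.nn as nn

# elementwise_affine=False: LayerNorm only normalizes and registers no learnable
# weight/bias. bias=False: Linear drops its additive bias parameter.
ln = nn.LayerNorm(8, elementwise_affine=False)
fc = nn.Linear(8, 32, bias=False)

assert ln.weight is None and ln.bias is None
assert fc.bias is None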
@@ -507,7 +510,9 @@ class OPTDecoder(OPTPreTrainedModel):
         # with checkpoints that have been fine-tuned before transformers v4.20.1
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
-            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+            self.final_layer_norm = nn.LayerNorm(
+                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
+            )
         else:
             self.final_layer_norm = None

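With the decoder's final layer norm also respecting the flag, Galactica checkpoints (which ship without linear biases and without affine layer norms) can be loaded through the regular OPT classes. A minimal sketch of loading and generating; the Hub id below is an assumption for illustration, not something this commit pins down:

from transformers import AutoTokenizer, OPTForCausalLM

model_id = "facebook/galactica-125m"  # assumed/illustrative checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OPTForCausalLM.from_pretrained(model_id)  # config supplies enable_bias / layer_norm_elementwise_affine

inputs = tokenizer("The Transformer architecture", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))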