mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix embeddings for PyTorch 1.8 (#10549)
* Fix embeddings for PyTorch 1.8 * Try with PyTorch 1.8.0 * Fix embeddings init * Fix copies * Typo * More typos
This commit is contained in:
parent
3e056c1003
commit
7da995c00c
@ -79,8 +79,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- run: pip install -U torch==1.7.1
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -107,8 +106,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- run: pip install -U torch==1.7.1
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -187,8 +185,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- run: pip install -U torch==1.7.1
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
@ -490,12 +490,16 @@ class AlbertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, (nn.Linear)) and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
@ -704,15 +704,19 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -177,15 +177,19 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
BERT_GENERATION_START_DOCSTRING = r"""
|
||||
|
@ -238,15 +238,19 @@ class ConvBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, (nn.Linear)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class SeparableConv1D(nn.Module):
|
||||
|
@ -221,12 +221,16 @@ class CTRLPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
||||
if isinstance(module, (nn.Linear, Conv1D)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
@ -767,13 +767,17 @@ class DebertaPreTrainedModel(PreTrainedModel):
|
||||
self._register_load_state_dict_pre_hook(self._pre_load_hook)
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
"""
|
||||
|
@ -886,13 +886,17 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
|
||||
self._register_load_state_dict_pre_hook(self._pre_load_hook)
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
"""
|
||||
|
@ -341,16 +341,19 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Embedding):
|
||||
if module.weight.requires_grad:
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
DISTILBERT_START_DOCSTRING = r"""
|
||||
|
@ -653,15 +653,19 @@ class ElectraPreTrainedModel(PreTrainedModel):
|
||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -779,6 +779,8 @@ class FunnelPreTrainedModel(PreTrainedModel):
|
||||
elif classname == "FunnelEmbeddings":
|
||||
std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
|
||||
nn.init.normal_(module.word_embeddings.weight, std=std)
|
||||
if module.word_embeddings.padding_idx is not None:
|
||||
module.word_embeddings.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class FunnelClassificationHead(nn.Module):
|
||||
|
@ -345,12 +345,16 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
||||
if isinstance(module, (nn.Linear, Conv1D)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
@ -645,15 +645,19 @@ class IBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (QuantLinear, QuantEmbedding, nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, (QuantLinear, nn.Linear)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (QuantEmbedding, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, (QuantLinear, nn.Linear)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")
|
||||
|
@ -612,15 +612,19 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, LayoutLMLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
LAYOUTLM_START_DOCSTRING = r"""
|
||||
|
@ -1363,15 +1363,19 @@ class LongformerPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
LONGFORMER_START_DOCSTRING = r"""
|
||||
|
@ -783,15 +783,19 @@ class LxmertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
LXMERT_START_DOCSTRING = r"""
|
||||
|
@ -670,15 +670,19 @@ class MobileBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, NoNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -56,15 +56,19 @@ class MPNetPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class MPNetEmbeddings(nn.Module):
|
||||
|
@ -283,12 +283,16 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
||||
if isinstance(module, (nn.Linear, Conv1D)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
@ -1791,16 +1791,17 @@ class ReformerPreTrainedModel(PreTrainedModel):
|
||||
torch.nn.init.normal_(weight, std=self.config.axial_norm_std)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -51,13 +51,17 @@ class RetriBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
RETRIBERT_START_DOCSTRING = r"""
|
||||
|
@ -574,15 +574,19 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
ROBERTA_START_DOCSTRING = r"""
|
||||
|
@ -432,15 +432,19 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding)):
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, SqueezeBertLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
SQUEEZEBERT_START_DOCSTRING = r"""
|
||||
|
@ -700,15 +700,19 @@ class TapasPreTrainedModel(PreTrainedModel):
|
||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
TAPAS_START_DOCSTRING = r"""
|
||||
|
@ -254,10 +254,12 @@ class XLMPreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, nn.Embedding):
|
||||
if self.config is not None and self.config.embed_init_std is not None:
|
||||
nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, nn.Linear):
|
||||
if self.config is not None and self.config.init_std is not None:
|
||||
nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
|
@ -552,12 +552,16 @@ class XLNetPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
@ -656,15 +656,19 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
|
Loading…
Reference in New Issue
Block a user