Fix embeddings for PyTorch 1.8 (#10549)

* Fix embeddings for PyTorch 1.8

* Try with PyTorch 1.8.0

* Fix embeddings init

* Fix copies

* Typo

* More typos
This commit is contained in:
Sylvain Gugger 2021-03-05 16:18:48 -05:00 committed by GitHub
parent 3e056c1003
commit 7da995c00c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 171 additions and 78 deletions

View File

@ -79,8 +79,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- run: pip install -U torch==1.7.1
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@ -107,8 +106,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- run: pip install -U torch==1.7.1
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
@ -187,8 +185,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- run: pip install -U torch==1.7.1
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:

View File

@ -490,12 +490,16 @@ class AlbertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear)) and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

View File

@ -704,15 +704,19 @@ class BertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass

View File

@ -177,15 +177,19 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
BERT_GENERATION_START_DOCSTRING = r"""

View File

@ -238,15 +238,19 @@ class ConvBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (nn.Linear)) and module.bias is not None:
module.bias.data.zero_()
class SeparableConv1D(nn.Module):

View File

@ -221,12 +221,16 @@ class CTRLPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
if isinstance(module, (nn.Linear, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

View File

@ -767,13 +767,17 @@ class DebertaPreTrainedModel(PreTrainedModel):
self._register_load_state_dict_pre_hook(self._pre_load_hook)
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""

View File

@ -886,13 +886,17 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
self._register_load_state_dict_pre_hook(self._pre_load_hook)
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""

View File

@ -341,16 +341,19 @@ class DistilBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Embedding):
if module.weight.requires_grad:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
DISTILBERT_START_DOCSTRING = r"""

View File

@ -653,15 +653,19 @@ class ElectraPreTrainedModel(PreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass

View File

@ -779,6 +779,8 @@ class FunnelPreTrainedModel(PreTrainedModel):
elif classname == "FunnelEmbeddings":
std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
nn.init.normal_(module.word_embeddings.weight, std=std)
if module.word_embeddings.padding_idx is not None:
module.word_embeddings.weight.data[module.padding_idx].zero_()
class FunnelClassificationHead(nn.Module):

View File

@ -345,12 +345,16 @@ class GPT2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
if isinstance(module, (nn.Linear, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

View File

@ -645,15 +645,19 @@ class IBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (QuantLinear, QuantEmbedding, nn.Linear, nn.Embedding)):
if isinstance(module, (QuantLinear, nn.Linear)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (QuantEmbedding, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (QuantLinear, nn.Linear)) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")

View File

@ -612,15 +612,19 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LayoutLMLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LAYOUTLM_START_DOCSTRING = r"""

View File

@ -1363,15 +1363,19 @@ class LongformerPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LONGFORMER_START_DOCSTRING = r"""

View File

@ -783,15 +783,19 @@ class LxmertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LXMERT_START_DOCSTRING = r"""

View File

@ -670,15 +670,19 @@ class MobileBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (nn.LayerNorm, NoNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass

View File

@ -56,15 +56,19 @@ class MPNetPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class MPNetEmbeddings(nn.Module):

View File

@ -283,12 +283,16 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
if isinstance(module, (nn.Linear, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

View File

@ -1791,16 +1791,17 @@ class ReformerPreTrainedModel(PreTrainedModel):
torch.nn.init.normal_(weight, std=self.config.axial_norm_std)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass

View File

@ -51,13 +51,17 @@ class RetriBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
RETRIBERT_START_DOCSTRING = r"""

View File

@ -574,15 +574,19 @@ class RobertaPreTrainedModel(PreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
ROBERTA_START_DOCSTRING = r"""

View File

@ -432,15 +432,19 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding)):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, SqueezeBertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
SQUEEZEBERT_START_DOCSTRING = r"""

View File

@ -700,15 +700,19 @@ class TapasPreTrainedModel(PreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
TAPAS_START_DOCSTRING = r"""

View File

@ -254,10 +254,12 @@ class XLMPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Embedding):
if self.config is not None and self.config.embed_init_std is not None:
nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, nn.Linear):
if self.config is not None and self.config.init_std is not None:
nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
if hasattr(module, "bias") and module.bias is not None:
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()

View File

@ -552,12 +552,16 @@ class XLNetPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

View File

@ -656,15 +656,19 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""