[fix] no warning for position_ids buffer (#6063)

This commit is contained in:
Sam Shleifer 2020-07-27 20:00:44 -04:00 committed by GitHub
parent 1e00ef681d
commit b7345d22d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 1 deletions

View File

@ -699,6 +699,8 @@ class BertModel(BertPreTrainedModel):
"""
authorized_missing_keys = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.config = config

View File

@ -788,6 +788,8 @@ class MobileBertModel(MobileBertPreTrainedModel):
https://arxiv.org/pdf/2004.02984.pdf
"""
authorized_missing_keys = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.config = config

View File

@ -272,6 +272,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
config_class = OpenAIGPTConfig
load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer"
authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights.

View File

@ -375,7 +375,9 @@ XLM_INPUTS_DOCSTRING = r"""
XLM_START_DOCSTRING,
)
class XLMModel(XLMPreTrainedModel):
def __init__(self, config): # , dico, is_encoder, with_output):
authorized_missing_keys = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
# encoder / decoder, output layer