Mirror of https://github.com/huggingface/transformers.git
Fix GPT-2 warnings (#11213)
* Fix GPT-2 warnings

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
commit 823df93955
parent 0cd89d8c83
@@ -380,6 +380,16 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
     missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
     missing_keys += missing_keys_pt
 
+    # Some models may have keys that are not in the state by design, removing them before needlessly warning
+    # the user.
+    if pt_model._keys_to_ignore_on_load_missing is not None:
+        for pat in pt_model._keys_to_ignore_on_load_missing:
+            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+    if pt_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in pt_model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
     if len(unexpected_keys) > 0:
         logger.warning(
             f"Some weights of the TF 2.0 model were not used when "
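For readers skimming the hunk above, a minimal standalone sketch of the pruning it adds (the key names below are hypothetical, not taken from the commit): each ignore pattern is applied with re.search, and a key is kept only if it matches none of them.

import re

# Hypothetical state-dict keys reported as missing after load_state_dict.
missing_keys = ["h.0.attn.masked_bias", "lm_head.weight", "h.0.mlp.c_fc.weight"]
# Patterns in the style of _keys_to_ignore_on_load_missing.
patterns = [r"attn\.masked_bias", r"lm_head\.weight"]

for pat in patterns:
    # Keep a key only when the pattern is found nowhere in it.
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['h.0.mlp.c_fc.weight'] -- the only key still warned about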
@@ -802,7 +802,7 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
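A note on the pattern change above: because the filter uses re.search, a pattern can match anywhere in the key, so dropping the h\.\d+\. prefix lets attn.masked_bias match that suffix at any layer index (the unescaped dots also happen to match the literal dots), and attn.bias is now covered as well. A quick self-check, again with hypothetical key names:

import re

keys = ["transformer.h.3.attn.masked_bias", "transformer.h.3.attn.bias", "lm_head.weight"]
patterns = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

for k in keys:
    # re.search matches a substring, so no layer-specific prefix is needed.
    assert any(re.search(pat, k) for pat in patterns), f"{k} would still trigger a warning"
print("all keys matched by an ignore pattern")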