From 57ffd8ab4c833e26b2288769f6031f94870a102c Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 4 May 2023 16:31:19 +0200
Subject: [PATCH] [`GPT-J`] Fix causal mask dtype (#23147)

* fix #23136

* better fix

* same fix for `masked_bias`
---
 src/transformers/models/gptj/modeling_gptj.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 18985cb3bce..3a1f99dd713 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -89,8 +89,9 @@ class GPTJAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -732,7 +733,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
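
Note (not part of the patch): a minimal sketch of what `persistent=False` does. Non-persistent buffers are rebuilt in `__init__` and left out of `state_dict()`, so the mask keeps the dtype the module creates it with instead of whatever a checkpoint stored; older checkpoints that still contain these keys then show up as *unexpected* rather than *missing* on load, which is why the regexes above move to `_keys_to_ignore_on_load_unexpected`. The `ToyAttention` module below is hypothetical and only mirrors the two `register_buffer` calls in the patch.

# Illustrative sketch, assuming only standard PyTorch APIs.
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, max_positions=4):
        super().__init__()
        # Non-persistent buffers are excluded from state_dict(), so they are
        # never written to checkpoints and are always recreated here.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
        self.proj = nn.Linear(8, 8)


module = ToyAttention()
# Prints ['proj.bias', 'proj.weight'] -- neither 'bias' nor 'masked_bias' is saved.
print(sorted(module.state_dict().keys()))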