[GPT-J] Fix causal mask dtype (#23147)

* fix #23136

* better fix: register the causal mask `bias` buffer with `persistent=False`, so it keeps the `torch.bool` dtype it is built with instead of being overwritten by the value stored in old checkpoints

* same fix for `masked_bias`
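
For context, a minimal sketch of why `persistent=False` addresses the dtype problem (a toy module, not the actual `GPTJAttention` code; the `ToyAttention` name and sizes are made up for illustration): non-persistent buffers are excluded from `state_dict()`, so a checkpoint that serialized the mask with another dtype can no longer overwrite the freshly constructed `torch.bool` mask at load time.

```python
# Toy sketch, not the transformers implementation: non-persistent buffers are
# excluded from state_dict(), so nothing stale from a checkpoint can replace
# the torch.bool causal mask built in __init__.
import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self, max_positions: int = 4):
        super().__init__()
        causal_mask = torch.tril(
            torch.ones((max_positions, max_positions), dtype=torch.bool)
        ).view(1, 1, max_positions, max_positions)
        # persistent=False keeps both buffers out of state_dict(), matching the fix.
        self.register_buffer("bias", causal_mask, persistent=False)
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)


model = ToyAttention()
print("bias" in model.state_dict())  # False: the mask is never loaded from a checkpoint
print(model.bias.dtype)              # torch.bool, exactly as constructed in __init__
```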
Younes Belkada 2023-05-04 16:31:19 +02:00 committed by GitHub
parent 83b38fbea8
commit 57ffd8ab4c

@@ -89,8 +89,9 @@ class GPTJAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -732,7 +733,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
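
A brief note on the `_keys_to_ignore_on_load_missing` → `_keys_to_ignore_on_load_unexpected` rename, with a hedged sketch (toy `OldAttention`/`NewAttention` modules, not the transformers loading code): once the buffers are non-persistent they no longer appear in the new model's own state dict, so an older checkpoint that still contains `h.N.attn.bias` / `h.N.attn.masked_bias` surfaces them as unexpected keys rather than missing ones, and that is the list the loader now needs to silence.

```python
# Hedged sketch with hypothetical toy modules: with persistent=False the mask
# buffer is absent from the new module's state_dict, so a checkpoint that still
# contains it yields an *unexpected* key instead of a *missing* one on load.
import torch
from torch import nn


class OldAttention(nn.Module):  # hypothetical stand-in for the pre-fix module
    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(1, 1, 4, 4, dtype=torch.uint8))


class NewAttention(nn.Module):  # hypothetical stand-in for the post-fix module
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "bias", torch.ones(1, 1, 4, 4, dtype=torch.bool), persistent=False
        )


old_checkpoint = OldAttention().state_dict()               # still contains "bias"
result = NewAttention().load_state_dict(old_checkpoint, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['bias'] -> hence _keys_to_ignore_on_load_unexpected
```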