Mirror of https://github.com/huggingface/transformers.git
[GPT-J] Fix causal mask dtype (#23147)

* fix #23136
* better fix
* same fix for `masked_bias`
parent 83b38fbea8
commit 57ffd8ab4c
@@ -89,8 +89,9 @@ class GPTJAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
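The fix leans on standard PyTorch buffer semantics: a buffer registered with `persistent=False` is never written to (or read from) the module's `state_dict`, so an old checkpoint can no longer overwrite the causal mask and change its dtype. A minimal sketch of that behavior, using a hypothetical `Demo` module that is not part of the commit:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer (the old behavior): saved into state_dict, so an
        # old checkpoint can overwrite the tensor -- and its dtype -- on load.
        self.register_buffer("saved_mask", torch.ones(4, dtype=torch.bool))
        # Non-persistent buffer (the fix): rebuilt in __init__ and never
        # serialized, so the dtype chosen here always wins.
        self.register_buffer("local_mask", torch.ones(4, dtype=torch.bool), persistent=False)

print(sorted(Demo().state_dict().keys()))  # ['saved_mask'] -- 'local_mask' is absent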
@@ -732,7 +733,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
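Making the buffers non-persistent also explains the second hunk: checkpoints saved before this change still contain `h.*.attn.bias` and `h.*.attn.masked_bias` entries that the model no longer declares, so on load those keys now surface as unexpected rather than missing. A minimal sketch of that effect, with hypothetical `Old`/`New` modules standing in for the two versions of the attention layer:

import torch
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(2, dtype=torch.bool))

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(2, dtype=torch.bool), persistent=False)

checkpoint = Old().state_dict()  # still contains "bias"
result = New().load_state_dict(checkpoint, strict=False)
print(result.missing_keys, result.unexpected_keys)  # [] ['bias']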