Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Commit 96e7ee7238
@@ -63,7 +63,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     scaled_attention_logits = matmul_qk / np.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += (mask * -1e4)
+        nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
+        scaled_attention_logits += (mask[ns-nd:ns, :ns] * -1e4)
 
     if attention_mask is not None:
         # Apply the attention mask
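In short, the added lines slice the precomputed causal mask so that only the rows for the nd freshly fed query positions are applied against all ns key positions (cached past plus current tokens). Below is a minimal, illustrative sketch of that slicing, not the library code; past_length and new_length are made-up values chosen only to show the resulting shape.

import torch

past_length, new_length = 3, 2             # hypothetical: tokens already cached vs. tokens fed now
ns = past_length + new_length              # total number of key positions
nd = new_length                            # number of query positions in this forward pass

# Full causal mask over all ns positions: 1 above the diagonal means "blocked".
full_mask = torch.triu(torch.ones(ns, ns), diagonal=1)

# Keep only the rows belonging to the nd new queries: each new query may attend
# to every cached key and to itself, but not to positions after it.
sliced = full_mask[ns - nd:ns, :ns]
print(sliced)
# tensor([[0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]])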
@@ -373,7 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):
         inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
+        mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
 
         inputs_embeds *= np.sqrt(self.d_model_size)
 
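This second hunk is the counterpart of the slicing above: once cached key/values from a previous step are in play, the triangular mask has to cover seq_len + past_length positions rather than seq_len alone, so the attention code can later cut out the rows belonging to the new queries. A small illustrative sketch with made-up sizes (not the surrounding model code):

import torch

seq_len, past_length = 2, 3                # hypothetical: new tokens vs. tokens already cached

# Old construction: a (seq_len, seq_len) mask knows nothing about the cached keys.
short_mask = torch.triu(torch.ones(seq_len, seq_len), 1)

# New construction: cover all past_length + seq_len key positions, matching the
# [ns - nd:ns, :ns] slice taken inside the attention function.
full_mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1)

print(short_mask.shape, full_mask.shape)   # torch.Size([2, 2]) torch.Size([5, 5])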