Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00.
Merge pull request #2495 from mschrimpf/patch-1
T5: move rp_bucket to relative_attention_bias' device
This commit is contained in:
commit
b1e1a9f9b2
@@ -286,6 +286,7 @@ class T5Attention(nn.Module):
             bidirectional=not self.is_decoder,
             num_buckets=self.relative_attention_num_buckets,
         )
+        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
         values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
         return values
Loading…
Reference in New Issue
Block a user