Merge pull request #2495 from mschrimpf/patch-1

T5: move rp_bucket to relative_attention_bias' device
This commit is contained in:
Thomas Wolf 2020-01-10 22:18:54 +01:00 committed by GitHub
commit b1e1a9f9b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,6 +286,7 @@ class T5Attention(nn.Module):
bidirectional=not self.is_decoder,
num_buckets=self.relative_attention_num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
return values