Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Author: iiLaurens
Committed: 2022-08-11 15:54:43 +02:00 (via GitHub)
Parent: 5d3f037433
Commit: d53dffec6e
2 changed files with 34 additions and 24 deletions
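
To see the effect end to end, an export smoke test along these lines can be run against a DeBERTa-v2 checkpoint. This is a sketch, not part of the commit; the checkpoint name, output file, and opset version are assumptions.

```python
# Sketch only: export a DeBERTa-v2 checkpoint and promote TracerWarnings to errors,
# so the export fails loudly if tracing still emits the warnings this commit targets.
import warnings

import torch
from transformers import AutoTokenizer, DebertaV2Model

checkpoint = "microsoft/deberta-v3-small"  # assumed: any DeBERTa-v2 architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = DebertaV2Model.from_pretrained(checkpoint, torchscript=True).eval()  # tuple outputs for tracing

inputs = tokenizer("ONNX export smoke test", return_tensors="pt")

with warnings.catch_warnings():
    warnings.simplefilter("error", torch.jit.TracerWarning)
    torch.onnx.export(
        model,
        (inputs["input_ids"], inputs["attention_mask"]),
        "deberta_v2.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "last_hidden_state": {0: "batch", 1: "sequence"},
        },
        opset_version=13,
    )
```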

src/transformers/models/deberta_v2/modeling_deberta_v2.py

@@ -14,11 +14,9 @@
 # limitations under the License.
 """ PyTorch DeBERTa-v2 model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -552,11 +550,17 @@ class DebertaV2Encoder(nn.Module):
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
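
As a small worked example of the rewritten bucketing (a sketch, not part of the commit): offsets with magnitude up to bucket_size/2 keep their exact value, while larger offsets are compressed logarithmically. The values below assume bucket_size=256 and max_position=512.

```python
# Sketch only: the new torch implementation, copied from the hunk above, applied to
# a handful of example offsets.
import torch


def make_log_bucket_position(relative_pos, bucket_size, max_position):
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    abs_pos = torch.where(
        (relative_pos < mid) & (relative_pos > -mid),
        torch.tensor(mid - 1).type_as(relative_pos),
        torch.abs(relative_pos),
    )
    log_pos = (
        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
    )
    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
    return bucket_pos


rel = torch.tensor([0, 1, 64, 127, 128, 200, 511, -200, -511])
print(make_log_bucket_position(rel, bucket_size=256, max_position=512))
# Offsets up to |128| keep their value; larger ones are log-compressed, e.g.
# 200 -> 169.0 and 511 -> 255.0 (the float result is cast to long downstream
# in build_relative_position).
```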
@@ -578,12 +582,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
 
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
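
The broadcasting change in the hunk above ("Use broadcasting instead of repeat" in the commit message) is behavior-preserving. A quick equivalence check, sketched with arbitrary example sizes:

```python
# Sketch only: the broadcast subtraction used above matches the removed np.tile construction.
import numpy as np
import torch

query_size, key_size = 5, 7  # arbitrary example sizes

q_np = np.arange(0, query_size)
k_np = np.arange(0, key_size)
old = q_np[:, None] - np.tile(k_np, (q_np.shape[0], 1))  # repeat k_ids explicitly, then subtract

q_pt = torch.arange(0, query_size)
k_pt = torch.arange(0, key_size)
new = q_pt[:, None] - k_pt[None, :]  # let broadcasting do the repetition

assert np.array_equal(old, new.numpy())
assert new.shape == (query_size, key_size)
```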
@@ -712,7 +716,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -787,7 +791,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -799,7 +803,7 @@ class DisentangledSelfAttention(nn.Module):
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
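
The last three hunks above swap `math.sqrt` on Python numbers for `torch.sqrt` on an explicit float tensor, per the "Force input to `sqrt` to be float type" commit message. A minimal numerical comparison of the two scale computations, with an assumed head size and scale factor:

```python
# Sketch only: compare the old math.sqrt scale with the new torch.sqrt scale.
# head_dim and scale_factor are assumed example values (e.g. 64-dim heads with
# both c2p and p2c attention enabled).
import math

import torch

head_dim, scale_factor = 64, 3

old_scale = math.sqrt(head_dim * scale_factor)  # plain Python float
new_scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)  # 0-dim float tensor

assert torch.isclose(new_scale, torch.tensor(old_scale, dtype=torch.float))
```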

src/transformers/models/sew_d/modeling_sew_d.py

@@ -194,11 +194,17 @@ def _compute_mask_indices(
 # Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
@@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
 
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
@@ -784,7 +790,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -859,7 +865,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -871,7 +877,7 @@ class DisentangledSelfAttention(nn.Module):
 
         # position->content
        if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),