From d53dffec6ef5f0cf28df3a1e7f70f1c5da5762ce Mon Sep 17 00:00:00 2001
From: iiLaurens
Date: Thu, 11 Aug 2022 15:54:43 +0200
Subject: [PATCH] Deberta V2: Fix critical trace warnings to allow ONNX export
 (#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun
---
 .../models/deberta_v2/modeling_deberta_v2.py | 30 +++++++++++--------
 .../models/sew_d/modeling_sew_d.py           | 28 ++++++++++-------
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index a513a8280ed..3243ee108d4 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -14,11 +14,9 @@
 # limitations under the License.
 """ PyTorch DeBERTa-v2 model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -552,11 +550,17 @@ class DebertaV2Encoder(nn.Module):
 
 
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
 
 
@@ -578,12 +582,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
 
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
@@ -712,7 +716,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -787,7 +791,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -799,7 +803,7 @@ class DisentangledSelfAttention(nn.Module):
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index a9a231aec1d..fe5836a80f3 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -194,11 +194,17 @@ def _compute_mask_indices(
 
 # Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
 
 
@@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
 
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
@@ -784,7 +790,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -859,7 +865,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -871,7 +877,7 @@ class DisentangledSelfAttention(nn.Module):
 
         # position->content
        if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
            if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
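
Note, not part of the patch above: a minimal sketch of why the numpy and Python-level math calls replaced here produce trace warnings. torch.jit.trace (which the ONNX export path relies on) only records torch operations; values routed through numpy or plain Python are baked into the graph as constants computed from the example input. The helper names below (bucket_sign_numpy, bucket_sign_torch) are illustrative only and do not appear in the patch.

import numpy as np
import torch


def bucket_sign_numpy(relative_pos):
    # The tensor -> ndarray -> tensor round trip is invisible to the tracer:
    # tracing this emits TracerWarnings and records the result as a constant.
    return torch.tensor(np.sign(relative_pos.numpy()))


def bucket_sign_torch(relative_pos):
    # Pure torch ops stay in the traced graph and follow the runtime input.
    return torch.sign(relative_pos)


example = torch.arange(-3, 4)
traced = torch.jit.trace(bucket_sign_torch, example)  # traces cleanly
# torch.jit.trace(bucket_sign_numpy, example)         # emits TracerWarnings
print(traced(torch.arange(-10, 11)))  # the traced graph still tracks new inputs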