Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
fix random attention for pytorch's bigbird/pegasus_bigbird (#23056)
* fix random attention usage for bigbird and pegasus_bigbird
* remove staticmethod, update tests target values
* revert style changes
This commit is contained in:
parent ef0c380c12
commit 6f8a02844a
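The gist of the fix: _bigbird_block_rand_mask (and its _with_head variant) become instance methods so they can consult self.training, and in eval mode they return an all-zero random-attention adjacency instead of sampling one. Below is a minimal sketch of that pattern; the class and method names are made up for illustration and the sampling step is simplified, so this is not the actual transformers implementation.

import numpy as np
import torch.nn as nn


class BlockSparseAttentionSketch(nn.Module):
    # Simplified stand-in for BigBirdBlockSparseAttention, showing only the
    # training-gated random-attention adjacency that this commit introduces.

    def _block_rand_mask(self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks):
        # One row of random "to"-block indices per non-global "from" block.
        rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
        # During inference (eval) no randomness: keep the adjacency deterministic.
        if not self.training:
            return rand_attn
        # Training: sample random middle blocks for each query block (the real code
        # additionally excludes each block's sliding-window neighbours).
        middle = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
        for i in range(rand_attn.shape[0]):
            rand_attn[i, :] = np.random.permutation(middle)[:num_rand_blocks]
        return rand_attn

Because the method now reads self.training, the @staticmethod decorator has to go, which is exactly the signature change in the first hunks below.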
@@ -1052,9 +1052,8 @@ class BigBirdBlockSparseAttention(nn.Module):

         return plan_from_length, plan_num_rand_blocks

-    @staticmethod
     def _bigbird_block_rand_mask(
-        from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
+        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
     ):
         """
         Create adjacency list of random attention.
@@ -1077,6 +1076,9 @@ class BigBirdBlockSparseAttention(nn.Module):
             raise ValueError("Error the number of blocks needs to be same!")

         rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
+        # During inference (eval) no randomness
+        if not self.training:
+            return rand_attn
         middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
         last = to_seq_length // to_block_size - 1
         if last_idx > (2 * to_block_size):
@@ -1160,11 +1162,17 @@ class BigBirdBlockSparseAttention(nn.Module):
         plan_block_length = np.array(plan_from_length) // from_block_size
         # till when to follow plan
         max_plan_idx = plan_from_length.index(from_seq_length)

         # Random Attention adjacency list
         rand_attn = [
             np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32)
             for i in range(num_heads)
         ]
+        # During inference (eval) no randomness
+        if not self.training:
+            for nh in range(num_heads):
+                rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :]
+            return rand_attn
+
         # We will go iteratively over the plan blocks and pick random number of
         # Attention blocks from the legally allowed blocks
@@ -1353,7 +1361,6 @@ class BigBirdAttention(nn.Module):
         attn_weights.key = self.self.key
         self.self = attn_weights
         self.attention_type = value
-
         if not self.training:
             self.self.eval()

@@ -1380,7 +1387,6 @@ class BigBirdAttention(nn.Module):
             from_mask = from_mask.to(hidden_states.dtype)
         if to_mask is not None:
             to_mask = to_mask.to(hidden_states.dtype)
-
         if self.attention_type == "original_full":
             self_outputs = self.self(
                 hidden_states,
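The two BigBirdAttention hunks above touch set_attention_type and forward; the behaviour that matters here is that the swapped-in attention submodule is kept in the same train/eval mode as its parent. A minimal sketch of that pattern, with made-up names rather than the actual class:

import torch.nn as nn


class SwitchableAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.self = nn.Identity()  # stand-in for the currently active attention implementation

    def set_attention_type(self, new_module: nn.Module):
        # Swap in the new implementation, then sync its train/eval mode with the parent,
        # because a newly constructed nn.Module starts out in training mode.
        self.self = new_module
        if not self.training:
            self.self.eval()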
@@ -873,9 +873,8 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):

         return plan_from_length, plan_num_rand_blocks

-    @staticmethod
     def _bigbird_block_rand_mask(
-        from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
+        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
     ):
         """
         Create adjacency list of random attention.
@@ -898,6 +897,9 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
             raise ValueError("Error the number of blocks needs to be same!")

         rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
+        # During inference (eval) no randomness
+        if not self.training:
+            return rand_attn
         middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
         last = to_seq_length // to_block_size - 1
         if last_idx > (2 * to_block_size):
@@ -981,11 +983,17 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
         plan_block_length = np.array(plan_from_length) // from_block_size
         # till when to follow plan
         max_plan_idx = plan_from_length.index(from_seq_length)

         # Random Attention adjacency list
         rand_attn = [
             np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32)
             for i in range(num_heads)
         ]
+        # During inference (eval) no randomness
+        if not self.training:
+            for nh in range(num_heads):
+                rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :]
+            return rand_attn
+
         # We will go iteratively over the plan blocks and pick random number of
         # Attention blocks from the legally allowed blocks
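With randomness disabled in eval mode, the block-sparse integration tests below need new target values, and the previously skipped PT/Flax equivalence tests can be re-enabled. A hypothetical sanity check of the new behaviour (not part of the commit; the checkpoint and input size are just an example):

import torch
from transformers import BigBirdModel

model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="block_sparse")
model.eval()  # with this fix, eval builds an all-zero (deterministic) random-attention adjacency

input_ids = torch.randint(100, 8000, (1, 1024))  # long enough for block-sparse attention to be used
with torch.no_grad():
    first = model(input_ids).last_hidden_state
    second = model(input_ids).last_hidden_state

assert torch.allclose(first, second)  # repeated eval-mode forward passes now agree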
@@ -20,7 +20,7 @@ import unittest

 from transformers import BigBirdConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
-from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, slow, torch_device

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -618,20 +618,6 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         else:
             super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

-    @is_pt_flax_cross_test
-    @unittest.skip(
-        reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
-    )
-    def test_equivalence_flax_to_pt(self):
-        pass
-
-    @is_pt_flax_cross_test
-    @unittest.skip(
-        reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
-    )
-    def test_equivalence_pt_to_flax(self):
-        pass
-

 @require_torch
 @slow
@@ -664,18 +650,19 @@ class BigBirdModelIntegrationTest(unittest.TestCase):

         expected_prediction_logits_slice = torch.tensor(
             [
-                [-0.2420, -0.6048, -0.0614, 7.8422],
-                [-0.0596, -0.0104, -1.8408, 9.3352],
-                [1.0588, 0.7999, 5.0770, 8.7555],
-                [-0.1385, -1.7199, -1.7613, 6.1094],
+                [-0.5583, 0.0475, -0.2508, 7.4423],
+                [0.7409, 1.4460, -0.7593, 7.7010],
+                [1.9150, 3.1395, 5.8840, 9.3498],
+                [-0.1854, -1.4640, -2.2052, 3.7968],
             ],
             device=torch_device,
         )

         self.assertTrue(
             torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
         )

-        expected_seq_relationship_logits = torch.tensor([[58.8196, 56.3629]], device=torch_device)
+        expected_seq_relationship_logits = torch.tensor([[46.9465, 47.9517]], device=torch_device)
         self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))

     def test_inference_full_pretraining(self):
@@ -787,22 +774,23 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
         blocked_mask, band_mask, from_mask, to_mask = model.create_masks_for_block_sparse_attn(
             attn_mask, config.block_size
         )

         targeted_cl = torch.tensor(
             [
-                [0.1874, 1.5260, 0.2335, -0.0473, -0.0961, 1.8384, -0.0141, 0.1250, 0.0085, -0.0048],
-                [-0.0554, 0.0728, 0.1683, -0.1332, 0.1741, 0.1337, -0.2380, -0.1849, -0.0390, -0.0259],
-                [-0.0419, 0.0767, 0.1591, -0.1399, 0.1789, 0.1257, -0.2406, -0.1772, -0.0261, -0.0079],
-                [0.1860, 1.5172, 0.2326, -0.0473, -0.0953, 1.8291, -0.0147, 0.1245, 0.0082, -0.0046],
-                [0.1879, 1.5296, 0.2335, -0.0471, -0.0975, 1.8433, -0.0136, 0.1260, 0.0086, -0.0054],
-                [0.1854, 1.5147, 0.2334, -0.0480, -0.0956, 1.8250, -0.0149, 0.1222, 0.0082, -0.0060],
-                [0.1859, 1.5184, 0.2334, -0.0474, -0.0955, 1.8297, -0.0143, 0.1234, 0.0079, -0.0054],
-                [0.1885, 1.5336, 0.2335, -0.0467, -0.0979, 1.8481, -0.0130, 0.1269, 0.0085, -0.0049],
-                [0.1881, 1.5305, 0.2335, -0.0471, -0.0976, 1.8445, -0.0135, 0.1262, 0.0086, -0.0053],
-                [0.1852, 1.5148, 0.2333, -0.0480, -0.0949, 1.8254, -0.0151, 0.1225, 0.0079, -0.0055],
-                [0.1877, 1.5292, 0.2335, -0.0470, -0.0972, 1.8431, -0.0135, 0.1259, 0.0084, -0.0052],
-                [0.1874, 1.5261, 0.2334, -0.0472, -0.0968, 1.8393, -0.0140, 0.1251, 0.0084, -0.0052],
-                [0.1853, 1.5151, 0.2331, -0.0478, -0.0948, 1.8256, -0.0154, 0.1228, 0.0086, -0.0052],
-                [0.1867, 1.5233, 0.2334, -0.0475, -0.0965, 1.8361, -0.0139, 0.1247, 0.0084, -0.0054],
+                [0.1870, 1.5248, 0.2333, -0.0483, -0.0952, 1.8359, -0.0142, 0.1239, 0.0083, -0.0045],
+                [-0.0601, 0.1243, 0.1329, -0.1524, 0.2347, 0.0894, -0.2248, -0.2461, -0.0645, -0.0109],
+                [-0.0418, 0.1463, 0.1290, -0.1638, 0.2489, 0.0799, -0.2341, -0.2406, -0.0524, 0.0106],
+                [0.1859, 1.5182, 0.2324, -0.0473, -0.0952, 1.8295, -0.0148, 0.1242, 0.0080, -0.0045],
+                [0.1879, 1.5300, 0.2334, -0.0480, -0.0967, 1.8428, -0.0137, 0.1256, 0.0087, -0.0050],
+                [0.1852, 1.5149, 0.2330, -0.0492, -0.0936, 1.8236, -0.0154, 0.1210, 0.0080, -0.0048],
+                [0.1857, 1.5186, 0.2331, -0.0484, -0.0940, 1.8285, -0.0148, 0.1224, 0.0077, -0.0045],
+                [0.1884, 1.5336, 0.2334, -0.0469, -0.0974, 1.8477, -0.0132, 0.1266, 0.0085, -0.0046],
+                [0.1881, 1.5308, 0.2334, -0.0479, -0.0969, 1.8438, -0.0136, 0.1258, 0.0088, -0.0050],
+                [0.1849, 1.5143, 0.2329, -0.0491, -0.0930, 1.8230, -0.0156, 0.1209, 0.0074, -0.0047],
+                [0.1878, 1.5299, 0.2333, -0.0472, -0.0967, 1.8434, -0.0137, 0.1257, 0.0084, -0.0048],
+                [0.1873, 1.5260, 0.2333, -0.0478, -0.0961, 1.8383, -0.0142, 0.1245, 0.0083, -0.0048],
+                [0.1849, 1.5145, 0.2327, -0.0491, -0.0935, 1.8237, -0.0156, 0.1215, 0.0083, -0.0046],
+                [0.1866, 1.5232, 0.2332, -0.0488, -0.0950, 1.8342, -0.0143, 0.1237, 0.0084, -0.0047],
             ],
             device=torch_device,
         )
@@ -851,21 +839,22 @@ class BigBirdModelIntegrationTest(unittest.TestCase):

         expected_prediction = torch.tensor(
             [
-                [-0.0213, -0.2213, -0.0061, 0.0687],
-                [0.0977, 0.1858, 0.2374, 0.0483],
-                [0.2112, -0.2524, 0.5793, 0.0967],
-                [0.2473, -0.5070, -0.0630, 0.2174],
-                [0.2885, 0.1139, 0.6071, 0.2991],
-                [0.2328, -0.2373, 0.3648, 0.1058],
-                [0.2517, -0.0689, 0.0555, 0.0880],
-                [0.1021, -0.1495, -0.0635, 0.1891],
-                [0.0591, -0.0722, 0.2243, 0.2432],
-                [-0.2059, -0.2679, 0.3225, 0.6183],
-                [0.2280, -0.2618, 0.1693, 0.0103],
-                [0.0183, -0.1375, 0.2284, -0.1707],
+                [0.1887, -0.0474, 0.2604, 0.1453],
+                [0.0651, 0.1999, 0.1797, 0.1161],
+                [0.2833, -0.3036, 0.6910, 0.1123],
+                [0.2836, -0.4644, -0.0111, 0.1530],
+                [0.3919, -0.2823, 0.4192, 0.1687],
+                [0.2168, -0.1956, 0.4050, 0.0925],
+                [0.2597, -0.0884, 0.1258, 0.1119],
+                [0.1127, -0.1203, 0.1924, 0.2859],
+                [0.1362, -0.1315, 0.2693, 0.1027],
+                [-0.3169, -0.2266, 0.4419, 0.6740],
+                [0.2366, -0.1452, 0.2589, 0.0579],
+                [0.0358, -0.2021, 0.3112, -0.1392],
             ],
             device=torch_device,
         )

         self.assertTrue(torch.allclose(prediction[0, 52:64, 320:324], expected_prediction, atol=1e-4))

     def test_inference_question_answering(self):
@@ -908,11 +897,12 @@ class BigBirdModelIntegrationTest(unittest.TestCase):

         # fmt: off
         target_start_logits = torch.tensor(
-            [[-8.9304, -10.3849, -14.4997, -9.6497, -13.9469, -7.8134, -8.9687, -13.3585, -9.7987, -13.8869, -9.2632, -8.9294, -13.6721, -7.3198, -9.5434, -11.2641, -14.3245, -9.5705, -12.7367, -8.6168, -11.083, -13.7573, -8.1151, -14.5329, -7.6876, -15.706, -12.8558, -9.1135, 8.0909, -3.1925, -11.5812, -9.4822], [-11.5595, -14.5591, -10.2978, -14.8445, -10.2092, -11.1899, -13.8356, -10.5644, -14.7706, -9.9841, -11.0052, -14.1862, -8.8173, -11.1098, -12.4686, -15.0531, -11.0196, -13.6614, -10.0236, -11.8151, -14.8744, -9.5123, -15.1605, -8.6472, -15.4184, -8.898, -9.6328, -7.0258, -11.3365, -14.4065, -10.2587, -8.9103]], # noqa: E231
+            [[-8.5622, -9.6209, -14.3351, -8.7032, -11.8596, -7.7446, -9.6730, -13.6063, -8.9651, -11.7417, -8.2641, -8.7056, -13.4116, -5.6600, -8.8316, -10.4148, -12.2180, -7.7979, -12.5274, -6.0685, -10.3373, -11.3128, -6.6456, -14.4030, -6.8292, -14.5383, -11.5638, -6.3326, 11.5293, -1.8434, -10.0013, -7.6150], [-10.7384, -13.1179, -10.1837, -13.7700, -10.0186, -11.7335, -13.3411, -10.0188, -13.4235, -9.9381, -10.4252, -13.1281, -8.2022, -10.4326, -11.5542, -14.1549, -10.7546, -13.4691, -8.2744, -11.4324, -13.3773, -9.8284, -14.5825, -8.7471, -14.7050, -8.0364, -11.3627, -6.4638, -11.7031, -14.3446, -9.9425, -8.0088]], # noqa: E231
             device=torch_device,
         )

         target_end_logits = torch.tensor(
-            [[-12.4131, -8.5959, -15.7163, -11.1524, -15.9913, -12.2038, -7.8902, -16.0296, -12.164, -16.5017, -13.3332, -6.9488, -15.7756, -13.8506, -11.0779, -9.2893, -15.0426, -10.1963, -17.3292, -12.2945, -11.5337, -16.4514, -9.1564, -17.5001, -9.1562, -16.2971, -13.3199, -7.5724, -5.1175, 7.2168, -10.3804, -11.9873], [-10.8654, -14.9967, -11.4144, -16.9189, -14.2673, -9.7068, -15.0182, -12.8846, -16.8716, -13.665, -10.3113, -15.1436, -14.9069, -13.3364, -11.2339, -16.0118, -11.8331, -17.0613, -13.8852, -12.4163, -16.8978, -10.7772, -17.2324, -10.6979, -16.9811, -10.3427, -9.497, -13.7104, -11.1107, -13.2936, -13.855, -14.1264]], # noqa: E231
+            [[-12.1736, -8.8487, -14.8877, -11.6713, -15.1165, -12.2396, -7.6828, -15.4153, -12.2528, -14.3671, -12.3596, -7.4272, -14.9615, -13.6356, -11.7939, -9.9767, -14.8112, -8.9567, -15.8798, -11.5291, -9.4249, -14.7544, -7.9387, -16.2789, -8.9702, -15.3111, -11.5585, -7.9992, -4.1127, 10.3209, -8.3926, -10.2005], [-11.1375, -15.4027, -12.6861, -16.9884, -13.7093, -10.3560, -15.7228, -12.9290, -15.8519, -13.7953, -10.2460, -15.7198, -14.2078, -12.8477, -11.4861, -16.1017, -11.8900, -16.4488, -13.2959, -10.3980, -15.4874, -10.3539, -16.8263, -10.9973, -17.0344, -9.2751, -10.1196, -13.8907, -12.1025, -13.0628, -12.8530, -13.8173]], # noqa: E321
             device=torch_device,
         )
         # fmt: on
@@ -954,7 +944,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):

         # fmt: off
         target = torch.tensor(
-            [[-0.045136, -0.068013, 0.12246, -0.01356, 0.018386, 0.025333, -0.0044439, -0.0030996, -0.064031, 0.0006439], [-0.045018, -0.067638, 0.12317, -0.013998, 0.019216, 0.025695, -0.0043705, -0.0031895, -0.063153, 0.00088899], [-0.045042, -0.067305, 0.1234, -0.014512, 0.020057, 0.026084, -0.004615, -0.0031728, -0.062442, 0.0010263], [-0.044589, -0.067655, 0.12416, -0.014287, 0.019416, 0.026065, -0.0050958, -0.002702, -0.063158, 0.0004827], [-0.044627, -0.067535, 0.1239, -0.014319, 0.019491, 0.026213, -0.0059482, -0.0025906, -0.063116, 0.00014669], [-0.044899, -0.067704, 0.12337, -0.014231, 0.019256, 0.026345, -0.0065565, -0.0022938, -0.063433, -0.00011409], [-0.045599, -0.067764, 0.12235, -0.014151, 0.019206, 0.026417, -0.0068965, -0.0024494, -0.063313, -4.4499e-06], [-0.045557, -0.068372, 0.12199, -0.013747, 0.017962, 0.026103, -0.0070607, -0.0023552, -0.06447, -0.00048756], [-0.045334, -0.068913, 0.1217, -0.013566, 0.01693, 0.025745, -0.006311, -0.0024903, -0.065575, -0.0006719], [-0.045171, -0.068726, 0.12164, -0.013688, 0.017139, 0.025629, -0.005213, -0.0029412, -0.065237, -0.00020669], [-0.044411, -0.069267, 0.12206, -0.013645, 0.016212, 0.025589, -0.0044121, -0.002972, -0.066277, -0.00067963], [-0.043487, -0.069792, 0.1232, -0.013663, 0.015303, 0.02613, -0.0036294, -0.0030616, -0.067483, -0.0012642], [-0.042622, -0.069287, 0.12469, -0.013936, 0.016204, 0.026474, -0.0040534, -0.0027365, -0.066994, -0.0014148], [-0.041879, -0.070031, 0.12593, -0.014047, 0.015082, 0.027751, -0.0040683, -0.0027189, -0.068985, -0.0027146]], # noqa: E231
+            [[-0.129420, -0.164740, 0.042422, -0.336030, 0.094379, 0.033794, 0.384590, 0.229660, -0.196500, 0.108020], [-0.000154, -0.168800, 0.165820, -0.313670, 0.101240, 0.035145, 0.381880, 0.213730, -0.201080, 0.077443], [0.053754, -0.166350, 0.225520, -0.272900, 0.119670, 0.019987, 0.348670, 0.199190, -0.181600, 0.084640], [0.063636, -0.187110, 0.237010, -0.297380, 0.126300, 0.020025, 0.268490, 0.191820, -0.192300, 0.035077], [0.073893, -0.184790, 0.188870, -0.297860, 0.134280, 0.028972, 0.174650, 0.186890, -0.180530, 0.006851], [0.005253, -0.169360, 0.123100, -0.302550, 0.126930, 0.024188, 0.133410, 0.200600, -0.168210, -0.001006], [-0.093336, -0.175370, -0.004768, -0.333170, 0.114330, 0.034168, 0.120960, 0.203570, -0.162810, -0.005757], [-0.160210, -0.169310, -0.049064, -0.331950, 0.115730, 0.027062, 0.143600, 0.205310, -0.144580, 0.026746], [-0.193200, -0.156820, -0.079422, -0.351600, 0.106450, 0.032174, 0.245690, 0.210250, -0.173480, 0.043914], [-0.167980, -0.153050, -0.059764, -0.357890,0.103910, 0.031481, 0.334190, 0.208960,-0.178180, 0.072165], [-0.136990, -0.156950, -0.012099, -0.353140,0.096996, 0.025864, 0.376340, 0.216050, -0.171820, 0.089963], [-0.041143, -0.167060, 0.079754, -0.353220, 0.093247, 0.019867, 0.385810, 0.214340, -0.191800, 0.065946],[0.040373, -0.158610, 0.152570, -0.312930, 0.110590, 0.012282, 0.345270, 0.204040, -0.176500, 0.064972], [0.043762, -0.166450, 0.179500, -0.317930, 0.117280, -0.004040, 0.304490, 0.201380, -0.182780, 0.044000]], # noqa: E231
             device=torch_device,
         )
         # fmt: on
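The updated "# fmt: off" target tensors above are the kind of values one would regenerate after this behaviour change; a hypothetical way to print such a slice for comparison (placeholder input, not the real test inputs):

import torch
from transformers import BigBirdForPreTraining

model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="block_sparse")
model.eval()

input_ids = torch.randint(100, 8000, (1, 1024))  # placeholder input ids
with torch.no_grad():
    output = model(input_ids)

print(output.prediction_logits[0, 128:132, 128:132])  # slice shape used by the pretraining assertion above
print(output.seq_relationship_logits)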
@@ -15,7 +15,7 @@
 import unittest

 from transformers import BigBirdConfig, is_flax_available
-from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
+from transformers.testing_utils import require_flax, slow

 from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -221,17 +221,3 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
             return
         else:
             super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
-
-    @is_pt_flax_cross_test
-    @unittest.skip(
-        reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
-    )
-    def test_equivalence_flax_to_pt(self):
-        pass
-
-    @is_pt_flax_cross_test
-    @unittest.skip(
-        reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
-    )
-    def test_equivalence_pt_to_flax(self):
-        pass