mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix tapas issue (#12063)
* Fix scatter function to be compatible with torch-scatter 2.7.0 * Allow test again
This commit is contained in:
parent
e56e3140dd
commit
70f88eeccc
@ -1697,9 +1697,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
|
||||
|
||||
segment_means = scatter(
|
||||
src=flat_values,
|
||||
index=flat_index.indices.type(torch.long),
|
||||
index=flat_index.indices.long(),
|
||||
dim=0,
|
||||
dim_size=flat_index.num_segments,
|
||||
dim_size=int(flat_index.num_segments),
|
||||
reduce=segment_reduce_fn,
|
||||
)
|
||||
|
||||
|
@ -1044,7 +1044,6 @@ class TapasUtilitiesTest(unittest.TestCase):
|
||||
# We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
|
||||
np.testing.assert_array_equal(maximum.numpy(), [2, 3])
|
||||
|
||||
@unittest.skip("Fix me I'm failing on CI")
|
||||
def test_reduce_sum_vectorized(self):
|
||||
values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
|
||||
index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)
|
||||
|
Loading…
Reference in New Issue
Block a user