Fix tapas issue (#12063)

* Fix scatter function to be compatible with torch-scatter 2.7.0

* Allow test again
This commit is contained in:
NielsRogge 2021-06-08 11:22:31 +02:00 committed by GitHub
parent e56e3140dd
commit 70f88eeccc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 3 deletions

View File

@ -1697,9 +1697,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
segment_means = scatter(
src=flat_values,
index=flat_index.indices.type(torch.long),
index=flat_index.indices.long(),
dim=0,
dim_size=flat_index.num_segments,
dim_size=int(flat_index.num_segments),
reduce=segment_reduce_fn,
)

View File

@ -1044,7 +1044,6 @@ class TapasUtilitiesTest(unittest.TestCase):
# We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
np.testing.assert_array_equal(maximum.numpy(), [2, 3])
@unittest.skip("Fix me I'm failing on CI")
def test_reduce_sum_vectorized(self):
values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)