fix device issue (#20227)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-11-15 15:21:16 +01:00 committed by GitHub
parent d3d5fa3e85
commit 6ed6ed29b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1784,9 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
# changed "view" by "reshape" in the following line
flat_values = values.reshape(flattened_shape.tolist())
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype)
out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
segment_means = out.scatter_reduce(
dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False
dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
)
# Unflatten the values.
@ -1799,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
dim=0,
)
output_values = segment_means.clone().view(new_shape.tolist())
output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
output_index = range_index_map(index.batch_shape(), index.num_segments)
return output_values, output_index