mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
fix device issue (#20227)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
d3d5fa3e85
commit
6ed6ed29b1
@ -1784,9 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
|
||||
# changed "view" by "reshape" in the following line
|
||||
flat_values = values.reshape(flattened_shape.tolist())
|
||||
|
||||
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype)
|
||||
out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
|
||||
segment_means = out.scatter_reduce(
|
||||
dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False
|
||||
dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
|
||||
)
|
||||
|
||||
# Unflatten the values.
|
||||
@ -1799,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
output_values = segment_means.clone().view(new_shape.tolist())
|
||||
output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
|
||||
output_index = range_index_map(index.batch_shape(), index.num_segments)
|
||||
return output_values, output_index
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user