Fix device issue for tapas (with as_tensor) (#37551)

* fix 1

* fix 2

* fix 3

* fix 4

* fix 5

* fix 6

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-04-16 16:02:53 +02:00 committed by GitHub
parent b33edf1b9b
commit 0577cae808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1563,7 +1563,7 @@ class IndexMap:
batch dimensions. Segments in different batch elements are always distinct even if they have the same
index.
"""
self.indices = torch.as_tensor(indices)
self.indices = torch.as_tensor(indices, device=indices.device)
self.num_segments = torch.as_tensor(num_segments, device=indices.device)
self.batch_dims = batch_dims
@ -1693,11 +1693,14 @@ def range_index_map(batch_shape, num_segments, name="range_index_map"):
Returns:
(`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
"""
device = num_segments.device if torch.is_tensor(num_segments) else "cpu"
batch_shape = torch.as_tensor(
batch_shape, dtype=torch.long
batch_shape, dtype=torch.long, device=device
) # create a rank 1 tensor vector containing batch_shape (e.g. [2])
assert len(batch_shape.size()) == 1
num_segments = torch.as_tensor(num_segments) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
num_segments = torch.as_tensor(
num_segments, device=device
) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
assert len(num_segments.size()) == 0
indices = torch.arange(
@ -1711,7 +1714,7 @@ def range_index_map(batch_shape, num_segments, name="range_index_map"):
new_shape = [int(x) for x in new_tensor.tolist()]
indices = indices.view(new_shape)
multiples = torch.cat([batch_shape, torch.as_tensor([1])], dim=0)
multiples = torch.cat([batch_shape, torch.as_tensor([1], device=device)], dim=0)
indices = indices.repeat(multiples.tolist())
# equivalent (in NumPy):
# indices = torch.as_tensor(np.tile(indices.numpy(), multiples.tolist()))
@ -1752,12 +1755,13 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
)
device = index.num_segments.device
# Unflatten the values.
new_shape = torch.cat(
[
torch.as_tensor(index.batch_shape(), dtype=torch.long),
torch.as_tensor([index.num_segments], dtype=torch.long),
torch.as_tensor(vector_shape, dtype=torch.long),
torch.as_tensor(index.batch_shape(), dtype=torch.long, device=device),
torch.as_tensor([index.num_segments], dtype=torch.long, device=device),
torch.as_tensor(vector_shape, dtype=torch.long, device=device),
],
dim=0,
)