mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix device issue for tapas (with as_tensor
) (#37551)
* fix 1 * fix 2 * fix 3 * fix 4 * fix 5 * fix 6 --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
b33edf1b9b
commit
0577cae808
@ -1563,7 +1563,7 @@ class IndexMap:
|
||||
batch dimensions. Segments in different batch elements are always distinct even if they have the same
|
||||
index.
|
||||
"""
|
||||
self.indices = torch.as_tensor(indices)
|
||||
self.indices = torch.as_tensor(indices, device=indices.device)
|
||||
self.num_segments = torch.as_tensor(num_segments, device=indices.device)
|
||||
self.batch_dims = batch_dims
|
||||
|
||||
@ -1693,11 +1693,14 @@ def range_index_map(batch_shape, num_segments, name="range_index_map"):
|
||||
Returns:
|
||||
(`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
|
||||
"""
|
||||
device = num_segments.device if torch.is_tensor(num_segments) else "cpu"
|
||||
batch_shape = torch.as_tensor(
|
||||
batch_shape, dtype=torch.long
|
||||
batch_shape, dtype=torch.long, device=device
|
||||
) # create a rank 1 tensor vector containing batch_shape (e.g. [2])
|
||||
assert len(batch_shape.size()) == 1
|
||||
num_segments = torch.as_tensor(num_segments) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
|
||||
num_segments = torch.as_tensor(
|
||||
num_segments, device=device
|
||||
) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
|
||||
assert len(num_segments.size()) == 0
|
||||
|
||||
indices = torch.arange(
|
||||
@ -1711,7 +1714,7 @@ def range_index_map(batch_shape, num_segments, name="range_index_map"):
|
||||
new_shape = [int(x) for x in new_tensor.tolist()]
|
||||
indices = indices.view(new_shape)
|
||||
|
||||
multiples = torch.cat([batch_shape, torch.as_tensor([1])], dim=0)
|
||||
multiples = torch.cat([batch_shape, torch.as_tensor([1], device=device)], dim=0)
|
||||
indices = indices.repeat(multiples.tolist())
|
||||
# equivalent (in Numpy:)
|
||||
# indices = torch.as_tensor(np.tile(indices.numpy(), multiples.tolist()))
|
||||
@ -1752,12 +1755,13 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
|
||||
dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
|
||||
)
|
||||
|
||||
device = index.num_segments.device
|
||||
# Unflatten the values.
|
||||
new_shape = torch.cat(
|
||||
[
|
||||
torch.as_tensor(index.batch_shape(), dtype=torch.long),
|
||||
torch.as_tensor([index.num_segments], dtype=torch.long),
|
||||
torch.as_tensor(vector_shape, dtype=torch.long),
|
||||
torch.as_tensor(index.batch_shape(), dtype=torch.long, device=device),
|
||||
torch.as_tensor([index.num_segments], dtype=torch.long, device=device),
|
||||
torch.as_tensor(vector_shape, dtype=torch.long, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user