mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix TVLT (torch device issue) (#21710)
* fix tvlt ci * fix tvlt ci * fix tvlt ci --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
4c6346cc3e
commit
03aaac3502
@ -153,7 +153,7 @@ def generate_pixel_mask_noise(pixel_values, pixel_mask=None, mask_ratio=0.75):
|
||||
"""Generate noise for audio masking."""
|
||||
|
||||
batch_size, seq_len = pixel_values.shape[:2]
|
||||
noise = torch.rand((batch_size, seq_len)) # noise in [0, 1]
|
||||
noise = torch.rand((batch_size, seq_len), device=pixel_values.device) # noise in [0, 1]
|
||||
len_keep = int(seq_len * (1 - mask_ratio))
|
||||
return noise, len_keep
|
||||
|
||||
@ -165,10 +165,13 @@ def generate_audio_mask_noise(audio_values, audio_mask=None, mask_ratio=0.75, ma
|
||||
if mask_type == "frame-level":
|
||||
num_time_patches = seq_len // freq_len
|
||||
noise = (
|
||||
torch.rand(batch_size, num_time_patches).unsqueeze(-1).repeat(1, 1, freq_len).view(batch_size, seq_len)
|
||||
torch.rand(batch_size, num_time_patches, device=audio_values.device)
|
||||
.unsqueeze(-1)
|
||||
.repeat(1, 1, freq_len)
|
||||
.view(batch_size, seq_len)
|
||||
) # noise in [0, 1]
|
||||
elif mask_type == "patch-level":
|
||||
noise = torch.rand(batch_size, seq_len) # noise in [0, 1]
|
||||
noise = torch.rand(batch_size, seq_len, device=audio_values.device) # noise in [0, 1]
|
||||
len_keep = int(seq_len * (1 - mask_ratio))
|
||||
return noise, len_keep
|
||||
|
||||
|
@ -590,7 +590,7 @@ class TvltModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]])
|
||||
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]], device=torch_device)
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs.last_hidden_state[:, :2, :2], expected_last_hidden_state_slice, atol=1e-4)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user