Fixes device mismatch

This commit is contained in:
yaswant19 2025-07-02 00:29:05 +05:30
parent 06737dd807
commit e0a0381631
2 changed files with 1 additions and 2 deletions

View File

@ -1133,7 +1133,6 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
print(query.device, hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (

View File

@ -508,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (