This commit is contained in:
yaswant19 2025-07-02 00:19:17 +05:30
parent 8667b4a57d
commit 06737dd807

View File

@ -1132,7 +1132,8 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
print(query.device, hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (