diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py
index 03dbb78a4ea..410be067357 100644
--- a/src/transformers/models/eomt/modeling_eomt.py
+++ b/src/transformers/models/eomt/modeling_eomt.py
@@ -1133,7 +1133,6 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
 
             if idx == self.num_hidden_layers - self.config.num_blocks:
                 query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
-                print(query.device, hidden_states.device)
                 hidden_states = torch.cat((query, hidden_states), dim=1)
 
             if idx >= self.num_hidden_layers - self.config.num_blocks and (
diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py
index 9a825dc5259..dd0b3bde521 100644
--- a/src/transformers/models/eomt/modular_eomt.py
+++ b/src/transformers/models/eomt/modular_eomt.py
@@ -508,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Module):
                 all_hidden_states += (hidden_states,)
 
             if idx == self.num_hidden_layers - self.config.num_blocks:
-                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
+                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
                 hidden_states = torch.cat((query, hidden_states), dim=1)
 
             if idx >= self.num_hidden_layers - self.config.num_blocks and (
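
Note (not part of the diff): `modular_eomt.py` is the source that `modeling_eomt.py` is generated from, so the `.to(hidden_states.device)` fix is ported into the modular file while the leftover debug `print` is dropped from the generated one. The cast matters when the model is sharded across devices (e.g. loaded with `device_map="auto"`): the learned `self.query` parameter can then sit on a different device than the incoming hidden states, and the subsequent `torch.cat` raises a device-mismatch error. A minimal sketch of that failure mode, assuming a hypothetical two-GPU placement and illustrative tensor shapes:

    import torch

    # Hypothetical placement: learned queries on GPU 0, activations on GPU 1,
    # as can happen when layers are dispatched with device_map="auto".
    query = torch.randn(1, 200, 768, device="cuda:0")
    hidden_states = torch.randn(1, 1024, 768, device="cuda:1")

    # Without the fix, torch.cat((query, hidden_states), dim=1) raises
    # "RuntimeError: Expected all tensors to be on the same device".
    query = query.to(hidden_states.device)  # the cast this diff adds
    hidden_states = torch.cat((query, hidden_states), dim=1)
    print(hidden_states.shape)  # torch.Size([1, 1224, 768])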