diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 959cdc6856f..a526ce5d7af 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -58,7 +58,7 @@ class JanusPreTrainedModel(PreTrainedModel): config_class = JanusConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] + _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -1133,6 +1133,7 @@ class JanusModel(JanusPreTrainedModel): image_features = image_embeds.reshape(-1, embed_dim) image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) + image_attention_mask = image_attention_mask.to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index a6965687781..0d484ffb0c0 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel): config_class = JanusConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] + _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -971,6 +971,7 @@ class JanusModel(JanusPreTrainedModel): image_features = image_embeds.reshape(-1, embed_dim) image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) + image_attention_mask = image_attention_mask.to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 48cf7ebc2f5..2729c0718c3 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -35,6 +35,7 @@ from transformers import ( from transformers.models.auto import get_values from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES from transformers.testing_utils import ( + Expectations, require_torch, slow, torch_device, @@ -538,12 +539,21 @@ class JanusIntegrationTest(unittest.TestCase): self.assertTrue(out.shape[1] == 576) # fmt: off - expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, - 14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761, - 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, - 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, - 1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676, - ]).to(model.device) + expected_tokens = Expectations( + { + ("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770, + 12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376, + 13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335, + 6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648 + ], + ("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, + 1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, + 10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, + 1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676 + ], + } + ) + expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device) # fmt: on # Compare the first 50 generated tokens.