[janus] Fix failing tests on mi3XX (#38426)

* Fix multiple devices error on Janus

* Fix AttributeError on Janus BOI token

* Initialize lm first in Janus to get correct device map

* Added expectations for Janus test_model_generate_images

* Fixed JanusVisionEncoderLayer being split across devices

* Code formatting

* Adding modeling file

* Reverted changes out of scope for this PR
This commit is contained in:
Rémi Ouazan 2025-06-04 09:38:10 +02:00 committed by GitHub
parent 78d771c3c2
commit 037acf1d10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 8 deletions

View File

@ -58,7 +58,7 @@ class JanusPreTrainedModel(PreTrainedModel):
config_class = JanusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_supports_flash_attn_2 = True
_supports_sdpa = True
@ -1133,6 +1133,7 @@ class JanusModel(JanusPreTrainedModel):
image_features = image_embeds.reshape(-1, embed_dim)
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)

View File

@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel):
config_class = JanusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_supports_flash_attn_2 = True
_supports_sdpa = True
@ -971,6 +971,7 @@ class JanusModel(JanusPreTrainedModel):
image_features = image_embeds.reshape(-1, embed_dim)
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)

View File

@ -35,6 +35,7 @@ from transformers import (
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.testing_utils import (
Expectations,
require_torch,
slow,
torch_device,
@ -538,12 +539,21 @@ class JanusIntegrationTest(unittest.TestCase):
self.assertTrue(out.shape[1] == 576)
# fmt: off
expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
]).to(model.device)
expected_tokens = Expectations(
{
("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
],
("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
],
}
)
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
# fmt: on
# Compare the first 50 generated tokens.