mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
[janus] Fix failing tests on mi3XX (#38426)
* Fix multiple devices error on Janus * Fix AttributeError on Janus BOI token * Initialize lm first in Janus to get correct device map * Added expectations for Janus test_model_generate_images * Fixed JanusVisionEncoderLayer being split across devices * Code formatting * Adding modeling file * Reverted changes out of scope for this PR
This commit is contained in:
parent
78d771c3c2
commit
037acf1d10
@ -58,7 +58,7 @@ class JanusPreTrainedModel(PreTrainedModel):
|
||||
config_class = JanusConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
@ -1133,6 +1133,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel):
|
||||
config_class = JanusConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
@ -971,6 +971,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -35,6 +35,7 @@ from transformers import (
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
@ -538,12 +539,21 @@ class JanusIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(out.shape[1] == 576)
|
||||
|
||||
# fmt: off
|
||||
expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
|
||||
14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
|
||||
14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
|
||||
14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
|
||||
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
|
||||
]).to(model.device)
|
||||
expected_tokens = Expectations(
|
||||
{
|
||||
("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
|
||||
12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
|
||||
13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
|
||||
6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
|
||||
],
|
||||
("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
|
||||
1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
|
||||
10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
|
||||
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
|
||||
],
|
||||
}
|
||||
)
|
||||
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
|
||||
# fmt: on
|
||||
|
||||
# Compare the first 50 generated tokens.
|
||||
|
Loading…
Reference in New Issue
Block a user