Mirror of https://github.com/huggingface/transformers.git
[janus] Fix failing tests on mi3XX (#38426)
* Fix multiple devices error on Janus
* Fix AttributeError on Janus BOI token
* Initialize lm first in Janus to get correct device map
* Added expectations for Janus test_model_generate_images
* Fixed JanusVisionEncoderLayer being split across devices
* Code formatting
* Adding modeling file
* Reverted changes out of scope for this PR
parent 78d771c3c2 · commit 037acf1d10
@@ -58,7 +58,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     config_class = JanusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
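For context (not part of the diff): when a model is loaded with `device_map="auto"`, accelerate's planner consults `_no_split_modules` to decide which submodules must stay whole on a single device. A minimal sketch of that mechanism, assuming a two-GPU host; the checkpoint name and memory limits are illustrative assumptions:

```python
# Sketch only: how _no_split_modules feeds accelerate's device-map planner.
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import JanusConfig, JanusForConditionalGeneration

config = JanusConfig.from_pretrained("deepseek-community/Janus-Pro-1B")  # assumed checkpoint name
with init_empty_weights():
    model = JanusForConditionalGeneration(config)  # meta-device skeleton, no weights loaded

# Classes in no_split_module_classes are treated as atomic: the planner may put
# a whole JanusVisionEncoderLayer on another GPU, but never spread its
# parameters across two devices, which was the failure mode on the MI3xx runners.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,  # now includes "JanusVisionEncoderLayer"
    max_memory={0: "4GiB", 1: "4GiB"},  # illustrative per-GPU budgets
)
```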
@@ -1133,6 +1133,7 @@ class JanusModel(JanusPreTrainedModel):
         image_features = image_embeds.reshape(-1, embed_dim)
         image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
 
+        image_attention_mask = image_attention_mask.to(inputs_embeds.device)
         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
 
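Why the added `.to(inputs_embeds.device)` is needed: with a device map that places the vision tower and the language model on different GPUs, the boolean mask is produced on the vision device, and `masked_scatter` requires mask and source to live on the same device as the target tensor. A hedged repro sketch, assuming two CUDA devices (device indices and shapes are illustrative):

```python
import torch

inputs_embeds = torch.zeros(1, 4, 8, device="cuda:1")           # language-model shard
mask = torch.zeros(1, 4, 8, dtype=torch.bool, device="cuda:0")  # built on the vision shard
mask[:, :2, :] = True
image_features = torch.ones(16, device="cuda:0")  # one value per True element in the mask

# inputs_embeds.masked_scatter(mask, image_features)  # -> RuntimeError: device mismatch
out = inputs_embeds.masked_scatter(
    mask.to(inputs_embeds.device),                                 # the fix in this hunk
    image_features.to(inputs_embeds.device, inputs_embeds.dtype),  # already present above
)
```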
@@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     config_class = JanusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -971,6 +971,7 @@ class JanusModel(JanusPreTrainedModel):
         image_features = image_embeds.reshape(-1, embed_dim)
         image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
 
+        image_attention_mask = image_attention_mask.to(inputs_embeds.device)
         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
 
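Note: the two pairs of hunks above apply the same fix to both of Janus's source files. In transformers, the modeling file is auto-generated from the modular file, so the edit has to land in both copies; this is likely what the "Adding modeling file" bullet in the commit message refers to.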
@@ -35,6 +35,7 @@ from transformers import (
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from transformers.testing_utils import (
+    Expectations,
     require_torch,
     slow,
     torch_device,
@@ -538,12 +539,21 @@ class JanusIntegrationTest(unittest.TestCase):
         self.assertTrue(out.shape[1] == 576)
 
         # fmt: off
-        expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
-                                        14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
-                                        14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
-                                        14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
-                                        1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
-                                        ]).to(model.device)
+        expected_tokens = Expectations(
+            {
+                ("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
+                                 12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
+                                 13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
+                                 6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
+                ],
+                ("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
+                                 1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
+                                 10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
+                                 1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
+                ],
+            }
+        )
+        expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
         # fmt: on
 
         # Compare the first 50 generated tokens.
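The `Expectations` helper imported above keys reference outputs by accelerator, so the same test can assert different tokens on NVIDIA and AMD runners. A hedged usage sketch with toy values; the exact matching rules belong to `transformers.testing_utils`, and `None` in the key appears to act as a wildcard over the device version:

```python
from transformers.testing_utils import Expectations

expected = Expectations(
    {
        ("cuda", None): [1, 2, 3],  # NVIDIA runners
        ("rocm", None): [4, 5, 6],  # AMD runners, e.g. the MI3xx CI machines
    }
)
# get_expectation() returns the entry matching the accelerator the test runs on.
tokens = expected.get_expectation()
```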