Mirror of https://github.com/huggingface/transformers.git
Fix config + attn_implementation in AutoModelForCausalLM.from_pretrained (#30299)
* Update modeling_utils.py
* Update test_modeling_utils.py
* Update test_modeling_utils.py
* Update test_modeling_utils.py
Parent: b1cd48740e
Commit: 21c912e79c
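Note (not part of the commit): a minimal usage sketch of the case this fix targets, where a user passes both a config object and an explicit attn_implementation kwarg to AutoModelForCausalLM.from_pretrained. The checkpoint id below is a placeholder assumption, and the comment about downstream dispatch is an inference from the diff, not a statement from the commit itself.

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint, for illustration only
config = AutoConfig.from_pretrained(model_id)

# With the old guard, requesting an implementation equal to the config's current value
# was never written back to the config, so later auto-dispatch could pick a different
# backend. After this fix the kwarg is always recorded on the config.
model = AutoModelForCausalLM.from_pretrained(model_id, config=config, attn_implementation="eager")
print(model.config._attn_implementation)  # expected: "eager"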
modeling_utils.py
@@ -3146,7 +3146,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             config = copy.deepcopy(config)
 
             kwarg_attn_imp = kwargs.pop("attn_implementation", None)
-            if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
+            if kwarg_attn_imp is not None:
                 config._attn_implementation = kwarg_attn_imp
 
             model_kwargs = kwargs
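Note (not part of the commit): a minimal sketch of why dropping the extra comparison matters. ToyConfig below is a simplified stand-in introduced here for illustration; it only assumes that the real config exposes _attn_implementation as a property backed by an internal attribute with a fallback default.

class ToyConfig:
    # Simplified stand-in for the real config object (illustration only).
    def __init__(self):
        self._attn_implementation_internal = None  # nothing explicitly recorded yet

    @property
    def _attn_implementation(self):
        return self._attn_implementation_internal or "eager"  # fallback default

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value


config = ToyConfig()
kwarg_attn_imp = "eager"  # user explicitly asks for eager

# Old guard: "eager" == "eager", so the assignment is skipped and the explicit
# request is never recorded on the config.
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
    config._attn_implementation = kwarg_attn_imp
print(config._attn_implementation_internal)  # None

# New guard: the request is always recorded.
if kwarg_attn_imp is not None:
    config._attn_implementation = kwarg_attn_imp
print(config._attn_implementation_internal)  # eager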
test_modeling_utils.py
@@ -427,6 +427,44 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
         self.assertEqual(model.dtype, torch.float32)
 
+    def test_model_from_pretrained_attn_implementation(self):
+        # test that the model can be instantiated with attn_implementation of either
+        # 1. explicit from_pretrained's attn_implementation argument
+        # 2. explicit from_pretrained's attn_implementation argument with a config argument
+        attn_implementation_available = ["eager"]
+        if is_torch_sdpa_available():
+            attn_implementation_available.append("sdpa")
+
+        if is_flash_attn_2_available():
+            attn_implementation_available.append("flash_attention_2")
+
+        mistral_attention_classes = {
+            "eager": "MistralAttention",
+            "sdpa": "MistralSdpaAttention",
+            "flash_attention_2": "MistralFlashAttention2",
+        }
+        for requested_attn_implementation in attn_implementation_available:
+            model = AutoModelForCausalLM.from_pretrained(
+                TINY_MISTRAL, attn_implementation=requested_attn_implementation
+            )
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
+            config = AutoConfig.from_pretrained(TINY_MISTRAL)
+            model = AutoModelForCausalLM.from_pretrained(
+                TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
+            )
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
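Note (not part of the commit): to exercise only the new test, something like the following should work, assuming the file lives at tests/test_modeling_utils.py in the repository checkout; the path and test node id are inferred from the file name in the commit message.

import pytest

# Run just the new attn_implementation test (path and -k pattern are assumptions).
pytest.main(["tests/test_modeling_utils.py", "-k", "test_model_from_pretrained_attn_implementation", "-q"])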