Fix config + attn_implementation in AutoModelForCausalLM.from_pretrained (#30299)

* Update modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py
This commit is contained in:
hoshi-hiyouga 2024-04-20 00:45:53 +08:00 committed by GitHub
parent b1cd48740e
commit 21c912e79c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 1 deletions

View File

@ -3146,7 +3146,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
if kwarg_attn_imp is not None:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs

View File

@ -427,6 +427,44 @@ class ModelUtilsTest(TestCasePlus):
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
def test_model_from_pretrained_attn_implementation(self):
# test that the model can be instantiated with attn_implementation of either
# 1. explicit from_pretrained's attn_implementation argument
# 2. explicit from_pretrained's attn_implementation argument with a config argument
attn_implementation_available = ["eager"]
if is_torch_sdpa_available():
attn_implementation_available.append("sdpa")
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
}
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
config = AutoConfig.from_pretrained(TINY_MISTRAL)
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)