mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fix: Fixed failing tests in tests/utils/test_add_new_model_like.py
(#32678)
* Fixed failing tests in tests/utils/test_add_new_model_like.py * Fixed formatting using ruff. * Small nit.
This commit is contained in:
parent
a22ff36e0e
commit
df323476a3
@ -61,6 +61,7 @@ VIT_MODEL_FILES = {
|
||||
"src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
|
||||
"src/transformers/models/vit/feature_extraction_vit.py",
|
||||
"src/transformers/models/vit/image_processing_vit.py",
|
||||
"src/transformers/models/vit/image_processing_vit_fast.py",
|
||||
"src/transformers/models/vit/modeling_vit.py",
|
||||
"src/transformers/models/vit/modeling_tf_vit.py",
|
||||
"src/transformers/models/vit/modeling_flax_vit.py",
|
||||
@ -662,7 +663,13 @@ NEW_BERT_CONSTANT = "value"
|
||||
def test_retrieve_model_classes(self):
|
||||
gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()}
|
||||
expected_gpt_classes = {
|
||||
"pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"},
|
||||
"pt": {
|
||||
"GPT2ForTokenClassification",
|
||||
"GPT2Model",
|
||||
"GPT2LMHeadModel",
|
||||
"GPT2ForSequenceClassification",
|
||||
"GPT2ForQuestionAnswering",
|
||||
},
|
||||
"tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
|
||||
"flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
|
||||
}
|
||||
@ -836,7 +843,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
]
|
||||
expected_model_classes = {
|
||||
"pt": set(wav2vec2_classes),
|
||||
"tf": {f"TF{m}" for m in wav2vec2_classes[:1]},
|
||||
"tf": {f"TF{m}" for m in [wav2vec2_classes[0], wav2vec2_classes[-2]]},
|
||||
"flax": {f"Flax{m}" for m in wav2vec2_classes[:2]},
|
||||
}
|
||||
|
||||
@ -870,7 +877,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2")
|
||||
self.assertEqual(wav2vec2_model_patterns.model_lower_cased, "wav2vec2")
|
||||
self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2")
|
||||
self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV_2_VEC_2")
|
||||
self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV2VEC2")
|
||||
self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config")
|
||||
self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor")
|
||||
self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor")
|
||||
|
Loading…
Reference in New Issue
Block a user