# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import re
import sys
import warnings
from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path

from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
from transformers.utils import ENV_VARS_TRUE_VALUES, direct_transformers_import


# All paths are set with the intent that you run this script from the root of the repo with the command
# python utils/check_repo.py
PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source/en"
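
# Illustrative usage, assuming you are at the root of a Transformers checkout:
#
#   python utils/check_repo.py
#
# Each check below raises an exception on failure, so the script stops at the first failing
# category of checks and exits with a non-zero status.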

# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "AltRobertaModel",
    "DPRSpanPredictor",
    "LongT5Stack",
    "RealmBertModel",
    "T5Stack",
    "MT5Stack",
    "UMT5Stack",
    "SwitchTransformersStack",
    "TFDPRSpanPredictor",
    "MaskFormerSwinModel",
    "MaskFormerSwinPreTrainedModel",
    "BridgeTowerTextModel",
    "BridgeTowerVisionModel",
]

# Update this list with models that are not tested, adding a comment explaining the reason they should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
    "InstructBlipQFormerModel",  # Building part of bigger (tested) model.
    "NllbMoeDecoder",
    "NllbMoeEncoder",
    "UMT5EncoderModel",  # Building part of bigger (tested) model.
    "LlamaDecoder",  # Building part of bigger (tested) model.
    "Blip2QFormerModel",  # Building part of bigger (tested) model.
    "DetaEncoder",  # Building part of bigger (tested) model.
    "DetaDecoder",  # Building part of bigger (tested) model.
    "ErnieMForInformationExtraction",
    "GraphormerEncoder",  # Building part of bigger (tested) model.
    "GraphormerDecoderHead",  # Building part of bigger (tested) model.
    "CLIPSegDecoder",  # Building part of bigger (tested) model.
    "TableTransformerEncoder",  # Building part of bigger (tested) model.
    "TableTransformerDecoder",  # Building part of bigger (tested) model.
    "TimeSeriesTransformerEncoder",  # Building part of bigger (tested) model.
    "TimeSeriesTransformerDecoder",  # Building part of bigger (tested) model.
    "InformerEncoder",  # Building part of bigger (tested) model.
    "InformerDecoder",  # Building part of bigger (tested) model.
    "AutoformerEncoder",  # Building part of bigger (tested) model.
    "AutoformerDecoder",  # Building part of bigger (tested) model.
    "JukeboxVQVAE",  # Building part of bigger (tested) model.
    "JukeboxPrior",  # Building part of bigger (tested) model.
    "DeformableDetrEncoder",  # Building part of bigger (tested) model.
    "DeformableDetrDecoder",  # Building part of bigger (tested) model.
    "OPTDecoder",  # Building part of bigger (tested) model.
    "FlaxWhisperDecoder",  # Building part of bigger (tested) model.
    "FlaxWhisperEncoder",  # Building part of bigger (tested) model.
    "WhisperDecoder",  # Building part of bigger (tested) model.
    "WhisperEncoder",  # Building part of bigger (tested) model.
    "DecisionTransformerGPT2Model",  # Building part of bigger (tested) model.
    "SegformerDecodeHead",  # Building part of bigger (tested) model.
    "PLBartEncoder",  # Building part of bigger (tested) model.
    "PLBartDecoder",  # Building part of bigger (tested) model.
    "PLBartDecoderWrapper",  # Building part of bigger (tested) model.
    "BigBirdPegasusEncoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DetrEncoder",  # Building part of bigger (tested) model.
    "DetrDecoder",  # Building part of bigger (tested) model.
    "DetrDecoderWrapper",  # Building part of bigger (tested) model.
    "ConditionalDetrEncoder",  # Building part of bigger (tested) model.
    "ConditionalDetrDecoder",  # Building part of bigger (tested) model.
    "M2M100Encoder",  # Building part of bigger (tested) model.
    "M2M100Decoder",  # Building part of bigger (tested) model.
    "MCTCTEncoder",  # Building part of bigger (tested) model.
    "MgpstrModel",  # Building part of bigger (tested) model.
    "Speech2TextEncoder",  # Building part of bigger (tested) model.
    "Speech2TextDecoder",  # Building part of bigger (tested) model.
    "LEDEncoder",  # Building part of bigger (tested) model.
    "LEDDecoder",  # Building part of bigger (tested) model.
    "BartDecoderWrapper",  # Building part of bigger (tested) model.
    "BartEncoder",  # Building part of bigger (tested) model.
    "BertLMHeadModel",  # Needs to be setup as decoder.
    "BlenderbotSmallEncoder",  # Building part of bigger (tested) model.
    "BlenderbotSmallDecoderWrapper",  # Building part of bigger (tested) model.
    "BlenderbotEncoder",  # Building part of bigger (tested) model.
    "BlenderbotDecoderWrapper",  # Building part of bigger (tested) model.
    "MBartEncoder",  # Building part of bigger (tested) model.
    "MBartDecoderWrapper",  # Building part of bigger (tested) model.
    "MegatronBertLMHeadModel",  # Building part of bigger (tested) model.
    "MegatronBertEncoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoderWrapper",  # Building part of bigger (tested) model.
    "MusicgenDecoder",  # Building part of bigger (tested) model.
    "MvpDecoderWrapper",  # Building part of bigger (tested) model.
    "MvpEncoder",  # Building part of bigger (tested) model.
    "PegasusEncoder",  # Building part of bigger (tested) model.
    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "PegasusXEncoder",  # Building part of bigger (tested) model.
    "PegasusXDecoder",  # Building part of bigger (tested) model.
    "PegasusXDecoderWrapper",  # Building part of bigger (tested) model.
    "DPREncoder",  # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    "RealmBertModel",  # Building part of bigger (tested) model.
    "RealmReader",  # Not regular model.
    "RealmScorer",  # Not regular model.
    "RealmForOpenQA",  # Not regular model.
    "ReformerForMaskedLM",  # Needs to be setup as decoder.
    "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
    "TFDPREncoder",  # Building part of bigger (tested) model.
    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
    "TFRobertaForMultipleChoice",  # TODO: fix
    "TFRobertaPreLayerNormForMultipleChoice",  # TODO: fix
    "TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
    "TFWhisperEncoder",  # Building part of bigger (tested) model.
    "TFWhisperDecoder",  # Building part of bigger (tested) model.
    "SeparableConv1D",  # Building part of bigger (tested) model.
    "FlaxBartForCausalLM",  # Building part of bigger (tested) model.
    "FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
    "OPTDecoderWrapper",
    "TFSegformerDecodeHead",  # Not a regular model.
    "AltRobertaModel",  # Building part of bigger (tested) model.
    "BlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
    "TFBlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
    "BridgeTowerTextModel",  # No need to test it as it is tested by BridgeTowerModel model.
    "BridgeTowerVisionModel",  # No need to test it as it is tested by BridgeTowerModel model.
    "SpeechT5Decoder",  # Building part of bigger (tested) model.
    "SpeechT5DecoderWithoutPrenet",  # Building part of bigger (tested) model.
    "SpeechT5DecoderWithSpeechPrenet",  # Building part of bigger (tested) model.
    "SpeechT5DecoderWithTextPrenet",  # Building part of bigger (tested) model.
    "SpeechT5Encoder",  # Building part of bigger (tested) model.
    "SpeechT5EncoderWithoutPrenet",  # Building part of bigger (tested) model.
    "SpeechT5EncoderWithSpeechPrenet",  # Building part of bigger (tested) model.
    "SpeechT5EncoderWithTextPrenet",  # Building part of bigger (tested) model.
    "SpeechT5SpeechDecoder",  # Building part of bigger (tested) model.
    "SpeechT5SpeechEncoder",  # Building part of bigger (tested) model.
    "SpeechT5TextDecoder",  # Building part of bigger (tested) model.
    "SpeechT5TextEncoder",  # Building part of bigger (tested) model.
    "BarkCausalModel",  # Building part of bigger (tested) model.
    "BarkModel",  # Does not have a forward signature - generation tested with integration tests
]

# Update this list with test files that don't have a tester with an `all_model_classes` variable and which don't
# trigger the common tests.
TEST_FILES_WITH_NO_COMMON_TESTS = [
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/camembert/test_modeling_camembert.py",
    "models/mt5/test_modeling_flax_mt5.py",
    "models/mbart/test_modeling_mbart.py",
    "models/mt5/test_modeling_mt5.py",
    "models/pegasus/test_modeling_pegasus.py",
    "models/camembert/test_modeling_tf_camembert.py",
    "models/mt5/test_modeling_tf_mt5.py",
    "models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
    "models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
    "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
    "models/xlm_roberta/test_modeling_xlm_roberta.py",
    "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
    "models/vision_text_dual_encoder/test_modeling_tf_vision_text_dual_encoder.py",
    "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/bark/test_modeling_bark.py",
]

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "AlignTextModel",
    "AlignVisionModel",
    "ClapTextModel",
    "ClapTextModelWithProjection",
    "ClapAudioModel",
    "ClapAudioModelWithProjection",
    "Blip2ForConditionalGeneration",
    "Blip2QFormerModel",
    "Blip2VisionModel",
    "ErnieMForInformationExtraction",
    "GitVisionModel",
    "GraphormerModel",
    "GraphormerForGraphClassification",
    "BlipForConditionalGeneration",
    "BlipForImageTextRetrieval",
    "BlipForQuestionAnswering",
    "BlipVisionModel",
    "BlipTextLMHeadModel",
    "BlipTextModel",
    "TFBlipForConditionalGeneration",
    "TFBlipForImageTextRetrieval",
    "TFBlipForQuestionAnswering",
    "TFBlipVisionModel",
    "TFBlipTextLMHeadModel",
    "TFBlipTextModel",
    "Swin2SRForImageSuperResolution",
    "BridgeTowerForImageAndTextRetrieval",
    "BridgeTowerForMaskedLM",
    "BridgeTowerForContrastiveLearning",
    "CLIPSegForImageSegmentation",
    "CLIPSegVisionModel",
    "CLIPSegTextModel",
    "EsmForProteinFolding",
    "GPTSanJapaneseModel",
    "TimeSeriesTransformerForPrediction",
    "InformerForPrediction",
    "AutoformerForPrediction",
    "JukeboxVQVAE",
    "JukeboxPrior",
    "PegasusXEncoder",
    "PegasusXDecoder",
    "PegasusXDecoderWrapper",
    "SamModel",
    "DPTForDepthEstimation",
    "DecisionTransformerGPT2Model",
    "GLPNForDepthEstimation",
    "ViltForImagesAndTextClassification",
    "ViltForImageAndTextRetrieval",
    "ViltForTokenClassification",
    "ViltForMaskedLM",
    "XGLMEncoder",
    "XGLMDecoder",
    "XGLMDecoderWrapper",
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "SegformerDecodeHead",
    "TFSegformerDecodeHead",
    "FlaxBeitForMaskedImageModeling",
    "PLBartEncoder",
    "PLBartDecoder",
    "PLBartDecoderWrapper",
    "BeitForMaskedImageModeling",
    "ChineseCLIPTextModel",
    "ChineseCLIPVisionModel",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "GroupViTTextModel",
    "GroupViTVisionModel",
    "TFCLIPTextModel",
    "TFCLIPVisionModel",
    "TFGroupViTTextModel",
    "TFGroupViTVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "FlaxWav2Vec2ForCTC",
    "DetrForSegmentation",
    "Pix2StructVisionModel",
    "Pix2StructTextModel",
    "Pix2StructForConditionalGeneration",
    "ConditionalDetrForSegmentation",
    "DPRReader",
    "FlaubertForQuestionAnswering",
    "FlavaImageCodebook",
    "FlavaTextModel",
    "FlavaImageModel",
    "FlavaMultimodalModel",
    "GPT2DoubleHeadsModel",
    "GPTSw3DoubleHeadsModel",
    "InstructBlipVisionModel",
    "InstructBlipQFormerModel",
    "LayoutLMForQuestionAnswering",
    "LukeForMaskedLM",
    "LukeForEntityClassification",
    "LukeForEntityPairClassification",
    "LukeForEntitySpanClassification",
    "MgpstrModel",
    "OpenAIGPTDoubleHeadsModel",
    "OwlViTTextModel",
    "OwlViTVisionModel",
    "OwlViTForObjectDetection",
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "RealmEmbedder",
    "RealmForOpenQA",
    "RealmScorer",
    "RealmReader",
    "TFDPRReader",
    "TFGPT2DoubleHeadsModel",
    "TFLayoutLMForQuestionAnswering",
    "TFOpenAIGPTDoubleHeadsModel",
    "TFRagModel",
    "TFRagSequenceForGeneration",
    "TFRagTokenForGeneration",
    "Wav2Vec2ForCTC",
    "HubertForCTC",
    "SEWForCTC",
    "SEWDForCTC",
    "XLMForQuestionAnswering",
    "XLNetForQuestionAnswering",
    "SeparableConv1D",
    "VisualBertForRegionToPhraseAlignment",
    "VisualBertForVisualReasoning",
    "VisualBertForQuestionAnswering",
    "VisualBertForMultipleChoice",
    "TFWav2Vec2ForCTC",
    "TFHubertForCTC",
    "XCLIPVisionModel",
    "XCLIPTextModel",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "AltRobertaModel",
    "TvltForAudioVisualClassification",
    "BarkCausalModel",
    "BarkCoarseModel",
    "BarkFineModel",
    "BarkSemanticModel",
    "MusicgenModel",
    "MusicgenForConditionalGeneration",
    "SpeechT5ForSpeechToSpeech",
    "SpeechT5ForTextToSpeech",
    "SpeechT5HifiGan",
]

# DO NOT edit this list!
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
    "FlaxBertLayer",
    "FlaxBigBirdLayer",
    "FlaxRoFormerLayer",
    "TFBertLayer",
    "TFLxmertEncoder",
    "TFLxmertXLayer",
    "TFMPNetLayer",
    "TFMobileBertLayer",
    "TFSegformerLayer",
    "TFViTMAELayer",
]

# Update this list for models that have multiple model types for the same
# model doc
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
    [
        ("data2vec-text", "data2vec"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-vision", "data2vec"),
        ("donut-swin", "donut"),
    ]
)
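
# For example, the "data2vec-text" model type is documented on the shared page
# docs/source/en/model_doc/data2vec.md rather than on a dedicated page of its own.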


# This is to make sure the transformers module imported is the one in the repo.
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)


def check_missing_backends():
    """Checks that PyTorch, TensorFlow and Flax are all installed; raises on CI, warns otherwise."""
    missing_backends = []
    if not is_torch_available():
        missing_backends.append("PyTorch")
    if not is_tf_available():
        missing_backends.append("TensorFlow")
    if not is_flax_available():
        missing_backends.append("Flax")
    if len(missing_backends) > 0:
        missing = ", ".join(missing_backends)
        if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
            raise Exception(
                "Full repo consistency checks require all backends to be installed (with `pip install -e .[dev]` in "
                f"the Transformers repo); the following are missing: {missing}."
            )
        else:
            warnings.warn(
                "Full repo consistency checks require all backends to be installed (with `pip install -e .[dev]` in "
                f"the Transformers repo); the following are missing: {missing}. While it's probably fine as long as "
                "you didn't make any change in one of those backends' modeling files, you should probably execute "
                "the command above to be on the safe side."
            )


def check_model_list():
    """Check the model list inside the transformers library."""
    # Get the models from the directory structure of `src/transformers/models/`
    models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
    _models = []
    for model in os.listdir(models_dir):
        if model == "deprecated":
            continue
        model_dir = os.path.join(models_dir, model)
        if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
            _models.append(model)

    # Get the models exposed in the `transformers.models` submodule.
    models = [model for model in dir(transformers.models) if not model.startswith("__")]

    missing_models = sorted(set(_models).difference(models))
    if missing_models:
        raise Exception(
            f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
        )


# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
    """Get the model modules inside the transformers library."""
    _ignore_modules = [
        "modeling_auto",
        "modeling_encoder_decoder",
        "modeling_marian",
        "modeling_mmbt",
        "modeling_outputs",
        "modeling_retribert",
        "modeling_utils",
        "modeling_flax_auto",
        "modeling_flax_encoder_decoder",
        "modeling_flax_utils",
        "modeling_speech_encoder_decoder",
        "modeling_flax_speech_encoder_decoder",
        "modeling_flax_vision_encoder_decoder",
        "modeling_timm_backbone",
        "modeling_transfo_xl_utilities",
        "modeling_tf_auto",
        "modeling_tf_encoder_decoder",
        "modeling_tf_outputs",
        "modeling_tf_pytorch_utils",
        "modeling_tf_utils",
        "modeling_tf_transfo_xl_utilities",
        "modeling_tf_vision_encoder_decoder",
        "modeling_vision_encoder_decoder",
    ]
    modules = []
    for model in dir(transformers.models):
        if model == "deprecated":
            continue
        # There are some magic dunder attributes in the dir, we ignore them
        if not model.startswith("__"):
            model_module = getattr(transformers.models, model)
            for submodule in dir(model_module):
                if submodule.startswith("modeling") and submodule not in _ignore_modules:
                    modeling_module = getattr(model_module, submodule)
                    if inspect.ismodule(modeling_module):
                        modules.append(modeling_module)
    return modules
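
# A returned entry is a modeling module such as `transformers.models.bert.modeling_bert`
# (illustrative example); TF and Flax modeling modules are collected as well.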


def get_models(module, include_pretrained=False):
    """Get the objects in module that are models."""
    models = []
    model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
    for attr_name in dir(module):
        if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
            models.append((attr_name, attr))
    return models
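
# Each returned entry is a `(name, class)` pair, e.g. `("BertModel", BertModel)` when called on
# `transformers.models.bert.modeling_bert` (illustrative example).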


def is_a_private_model(model):
    """Returns True if the model should not be in the main init."""
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder, Decoder and Prenet classes are all private.
    if model.endswith("Wrapper"):
        return True
    if model.endswith("Encoder"):
        return True
    if model.endswith("Decoder"):
        return True
    if model.endswith("Prenet"):
        return True
    return False


def check_models_are_in_init():
    """Checks all models defined in the library are in the main init."""
    models_not_in_init = []
    dir_transformers = dir(transformers)
    for module in get_model_modules():
        models_not_in_init += [
            model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
        ]

    # Remove private models
    models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
    if len(models_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")


# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
    """Get the model test files.

    The returned files do NOT contain the `tests` prefix (i.e. `PATH_TO_TESTS` defined in this script); they are
    paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
    """

    _ignore_files = [
        "test_modeling_common",
        "test_modeling_encoder_decoder",
        "test_modeling_flax_encoder_decoder",
        "test_modeling_flax_speech_encoder_decoder",
        "test_modeling_marian",
        "test_modeling_tf_common",
        "test_modeling_tf_encoder_decoder",
    ]
    test_files = []
    # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
    model_test_root = os.path.join(PATH_TO_TESTS, "models")
    model_test_dirs = []
    for x in os.listdir(model_test_root):
        x = os.path.join(model_test_root, x)
        if os.path.isdir(x):
            model_test_dirs.append(x)

    for target_dir in [PATH_TO_TESTS] + model_test_dirs:
        for file_or_dir in os.listdir(target_dir):
            path = os.path.join(target_dir, file_or_dir)
            if os.path.isfile(path):
                filename = os.path.split(path)[-1]
                if "test_modeling" in filename and os.path.splitext(filename)[0] not in _ignore_files:
                    file = os.path.join(*path.split(os.sep)[1:])
                    test_files.append(file)

    return test_files
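
# A returned entry looks like "models/bert/test_modeling_bert.py" (illustrative example), i.e. it is
# relative to the `tests` directory, as described in the docstring above.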


# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
    """Parse the content of test_file to detect what's in all_model_classes"""
    # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
    with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
        content = f.read()
    all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
    # Check with one less parenthesis as well
    all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
    if len(all_models) > 0:
        model_tested = []
        for entry in all_models:
            for line in entry.split(","):
                name = line.strip()
                if len(name) > 0:
                    model_tested.append(name)
        return model_tested
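
# The two regexes above match declarations of the following shapes (illustrative examples):
#
#     all_model_classes = (BertModel, BertForMaskedLM) if is_torch_available() else ()
#     all_model_classes = ((BertModel, BertForMaskedLM) if is_torch_available() else ())
#
# Note that the function implicitly returns `None` when no `all_model_classes` declaration is found;
# `check_models_are_tested` below relies on that.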


def check_models_are_tested(module, test_file):
    """Check models defined in module are tested in test_file."""
    # XxxPreTrainedModel are not tested
    defined_models = get_models(module)
    tested_models = find_tested_models(test_file)
    if tested_models is None:
        if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS:
            return
        return [
            f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. "
            + "If this is intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file "
            + "`utils/check_repo.py`."
        ]
    failures = []
    for model_name, _ in defined_models:
        if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
            failures.append(
                f"{model_name} is defined in {module.__name__} but is not tested in "
                + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the `all_model_classes` in that file. "
                + "If common tests should not be applied to that model, add its name to `IGNORE_NON_TESTED` "
                + "in the file `utils/check_repo.py`."
            )
    return failures


def check_all_models_are_tested():
    """Check all models are properly tested."""
    modules = get_model_modules()
    test_files = get_model_test_files()
    failures = []
    for module in modules:
        test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
        if len(test_file) == 0:
            failures.append(f"{module.__name__} does not have its corresponding test file.")
        elif len(test_file) > 1:
            failures.append(f"{module.__name__} has several test files: {test_file}.")
        else:
            test_file = test_file[0]
            new_failures = check_models_are_tested(module, test_file)
            if new_failures is not None:
                failures += new_failures
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def get_all_auto_configured_models():
    """Return the list of all models in at least one auto class."""
    result = set()  # To avoid duplicates we concatenate all model classes in a set.
    if is_torch_available():
        for attr_name in dir(transformers.models.auto.modeling_auto):
            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
    if is_tf_available():
        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
    if is_flax_available():
        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
    return list(result)
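
# For instance, `MODEL_FOR_MASKED_LM_MAPPING_NAMES` in `modeling_auto` contributes names such as
# "BertForMaskedLM" to the returned list (illustrative example).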


def ignore_unautoclassed(model_name):
    """Rules to determine if `model_name` should be in an auto class."""
    # Special white list
    if model_name in IGNORE_NON_AUTO_CONFIGURED:
        return True
    # Encoder and Decoder should be ignored
    if "Encoder" in model_name or "Decoder" in model_name:
        return True
    return False


def check_models_are_auto_configured(module, all_auto_models):
    """Check models defined in module are each in an auto class."""
    defined_models = get_models(module)
    failures = []
    for model_name, _ in defined_models:
        if model_name not in all_auto_models and not ignore_unautoclassed(model_name):
            failures.append(
                f"{model_name} is defined in {module.__name__} but is not present in any of the auto mappings. "
                "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
                "`utils/check_repo.py`."
            )
    return failures


def check_all_models_are_auto_configured():
    """Check all models are each in an auto class."""
    check_missing_backends()
    modules = get_model_modules()
    all_auto_models = get_all_auto_configured_models()
    failures = []
    for module in modules:
        new_failures = check_models_are_auto_configured(module, all_auto_models)
        if new_failures is not None:
            failures += new_failures
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def check_all_auto_object_names_being_defined():
    """Check all names defined in auto (name) mappings exist in the library."""
    check_missing_backends()

    failures = []
    mappings_to_check = {
        "TOKENIZER_MAPPING_NAMES": TOKENIZER_MAPPING_NAMES,
        "IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
        "FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
        "PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
    }

    # Each auto modeling file contains multiple mappings. Let's get them in a dynamic way.
    for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]:
        module = getattr(transformers.models.auto, module_name, None)
        if module is None:
            continue
        # all mappings in a single auto modeling file
        mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
        mappings_to_check.update({name: getattr(module, name) for name in mapping_names})

    for name, mapping in mappings_to_check.items():
        for model_type, class_names in mapping.items():
            if not isinstance(class_names, tuple):
                class_names = (class_names,)
            for class_name in class_names:
                if class_name is None:
                    continue
                # dummy object is accepted
                if not hasattr(transformers, class_name):
                    # If the class name is in a model name mapping and is a private model defined in this file,
                    # don't check whether there is a definition in any modeling module.
                    if name.endswith("MODEL_MAPPING_NAMES") and is_a_private_model(class_name):
                        continue
                    failures.append(
                        f"`{class_name}` appears in the mapping `{name}` but it is not defined in the library."
                    )
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def check_all_auto_mapping_names_in_config_mapping_names():
    """Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
    check_missing_backends()

    failures = []
    # `TOKENIZER_MAPPING_NAMES` and `AutoTokenizer` are special and don't need to follow the rule.
    mappings_to_check = {
        "IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
        "FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
        "PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
    }

    # Each auto modeling file contains multiple mappings. Let's get them in a dynamic way.
    for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]:
        module = getattr(transformers.models.auto, module_name, None)
        if module is None:
            continue
        # all mappings in a single auto modeling file
        mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
        mappings_to_check.update({name: getattr(module, name) for name in mapping_names})

    for name, mapping in mappings_to_check.items():
        for model_type, class_names in mapping.items():
            if model_type not in CONFIG_MAPPING_NAMES:
                failures.append(
                    f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
                    "`CONFIG_MAPPING_NAMES`."
                )
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def check_all_auto_mappings_importable():
    """Check all auto mappings could be imported."""
    check_missing_backends()

    failures = []
    mappings_to_check = {}
    # Each auto modeling file contains multiple mappings. Let's get them in a dynamic way.
    for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]:
        module = getattr(transformers.models.auto, module_name, None)
        if module is None:
            continue
        # all mappings in a single auto modeling file
        mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
        mappings_to_check.update({name: getattr(module, name) for name in mapping_names})

    for name, _ in mappings_to_check.items():
        name = name.replace("_MAPPING_NAMES", "_MAPPING")
        if not hasattr(transformers, name):
            failures.append(f"`{name}`")
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def check_objects_being_equally_in_main_init():
    """Check if an object is in the main __init__ if its counterpart in PyTorch is."""
    attrs = dir(transformers)

    failures = []
    for attr in attrs:
        obj = getattr(transformers, attr)
        if hasattr(obj, "__module__"):
            module_path = obj.__module__
            if "models.deprecated" in module_path:
                continue
            module_name = module_path.split(".")[-1]
            module_dir = ".".join(module_path.split(".")[:-1])
            if (
                module_name.startswith("modeling_")
                and not module_name.startswith("modeling_tf_")
                and not module_name.startswith("modeling_flax_")
            ):
                parent_module = sys.modules[module_dir]

                frameworks = []
                if is_tf_available():
                    frameworks.append("TF")
                if is_flax_available():
                    frameworks.append("Flax")

                for framework in frameworks:
                    other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
                    if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
                        other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
                        other_module = getattr(parent_module, other_module_name)
                        if hasattr(other_module, f"{framework}{attr}"):
                            if not hasattr(transformers, f"{framework}{attr}"):
                                if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
                                    failures.append(f"{framework}{attr}")
                        if hasattr(other_module, f"{framework}_{attr}"):
                            if not hasattr(transformers, f"{framework}_{attr}"):
                                if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
                                    failures.append(f"{framework}_{attr}")
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
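
# For example (illustrative): if `BertModel` lives in `modeling_bert` and `modeling_tf_bert` defines
# `TFBertModel`, then `TFBertModel` must also be exposed in the main `__init__`, unless it is listed
# in `OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK`.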


_re_decorator = re.compile(r"^\s*@(\S+)\s+$")


def check_decorator_order(filename):
    """Check that in the test file `filename` the slow decorator is always last."""
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    decorator_before = None
    errors = []
    for i, line in enumerate(lines):
        search = _re_decorator.search(line)
        if search is not None:
            decorator_name = search.groups()[0]
            if decorator_before is not None and decorator_name.startswith("parameterized"):
                errors.append(i)
            decorator_before = decorator_name
        elif decorator_before is not None:
            decorator_before = None
    return errors
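
# Illustrative example of what gets flagged in a test file:
#
#     @slow
#     @parameterized.expand(SOME_CASES)
#     def test_something(self, ...):
#
# is reported, because a `parameterized` decorator (detected via its prefix) must come before any
# other decorator such as `@slow`. (`SOME_CASES` is a hypothetical placeholder.)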


def check_all_decorator_order():
    """Check that in all test files, the slow decorator is always last."""
    errors = []
    for fname in os.listdir(PATH_TO_TESTS):
        if fname.endswith(".py"):
            filename = os.path.join(PATH_TO_TESTS, fname)
            new_errors = check_decorator_order(filename)
            errors += [f"- {filename}, line {i}" for i in new_errors]
    if len(errors) > 0:
        msg = "\n".join(errors)
        raise ValueError(
            "The parameterized decorator (and its variants) should always be first, but this is not the case in the"
            f" following files:\n{msg}"
        )


def find_all_documented_objects():
    """Parse the content of all doc files to detect which classes and functions they document."""
    documented_obj = []
    for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
        with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
            content = f.read()
        raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content)
        documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
    for doc_file in Path(PATH_TO_DOC).glob("**/*.md"):
        with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
            content = f.read()
        raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content)
        documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
    return documented_obj
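
# The two patterns above match documentation directives like the following (illustrative examples):
#
#     .. autoclass:: transformers.BertModel    (rst doc pages)
#     [[autodoc]] BertModel                    (md doc pages)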


# One good reason for not being documented is to be deprecated. Put deprecated objects in this list.
DEPRECATED_OBJECTS = [
    "AutoModelWithLMHead",
    "BartPretrainedModel",
    "DataCollator",
    "DataCollatorForSOP",
    "GlueDataset",
    "GlueDataTrainingArguments",
    "LineByLineTextDataset",
    "LineByLineWithRefDataset",
    "LineByLineWithSOPTextDataset",
    "PretrainedBartModel",
    "PretrainedFSMTModel",
    "SingleSentenceClassificationProcessor",
    "SquadDataTrainingArguments",
    "SquadDataset",
    "SquadExample",
    "SquadFeatures",
    "SquadV1Processor",
    "SquadV2Processor",
    "TFAutoModelWithLMHead",
    "TFBartPretrainedModel",
    "TextDataset",
    "TextDatasetForNextSentencePrediction",
    "Wav2Vec2ForMaskedLM",
    "Wav2Vec2Tokenizer",
    "glue_compute_metrics",
    "glue_convert_examples_to_features",
    "glue_output_modes",
    "glue_processors",
    "glue_tasks_num_labels",
    "squad_convert_examples_to_features",
    "xnli_compute_metrics",
    "xnli_output_modes",
    "xnli_processors",
    "xnli_tasks_num_labels",
    "TFTrainer",
    "TFTrainingArguments",
]

# Exceptionally, some objects should not be documented after all rules passed.
# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT!
UNDOCUMENTED_OBJECTS = [
    "AddedToken",  # This is a tokenizers class.
    "BasicTokenizer",  # Internal, should never have been in the main init.
    "CharacterTokenizer",  # Internal, should never have been in the main init.
    "DPRPretrainedReader",  # Like an Encoder.
    "DummyObject",  # Just picked by mistake sometimes.
    "MecabTokenizer",  # Internal, should never have been in the main init.
    "ModelCard",  # Internal type.
    "SqueezeBertModule",  # Internal building block (should have been called SqueezeBertLayer)
    "TFDPRPretrainedReader",  # Like an Encoder.
    "TransfoXLCorpus",  # Internal type.
    "WordpieceTokenizer",  # Internal, should never have been in the main init.
    "absl",  # External module
    "add_end_docstrings",  # Internal, should never have been in the main init.
    "add_start_docstrings",  # Internal, should never have been in the main init.
    "convert_tf_weight_name_to_pt_weight_name",  # Internal used to convert model weights
    "logger",  # Internal logger
    "logging",  # External module
    "requires_backends",  # Internal function
    "AltRobertaModel",  # Internal module
    "FalconConfig",  # TODO Matt Remove this and re-add the docs once TGI is ready
    "FalconForCausalLM",
    "FalconForQuestionAnswering",
    "FalconForSequenceClassification",
    "FalconForTokenClassification",
    "FalconModel",
]

# This list should be empty. Objects in it should get their own doc page.
SHOULD_HAVE_THEIR_OWN_PAGE = [
    # Benchmarks
    "PyTorchBenchmark",
    "PyTorchBenchmarkArguments",
    "TensorFlowBenchmark",
    "TensorFlowBenchmarkArguments",
    "AutoBackbone",
    "BitBackbone",
    "ConvNextBackbone",
    "ConvNextV2Backbone",
    "DinatBackbone",
    "FocalNetBackbone",
    "MaskFormerSwinBackbone",
    "MaskFormerSwinConfig",
    "MaskFormerSwinModel",
    "NatBackbone",
    "ResNetBackbone",
    "SwinBackbone",
    "TimmBackbone",
    "TimmBackboneConfig",
]


def ignore_undocumented(name):
    """Rules to determine if `name` should be undocumented."""
    # NOT DOCUMENTED ON PURPOSE.
    # Constants uppercase are not documented.
    if name.isupper():
        return True
    # PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
    if (
        name.endswith("PreTrainedModel")
        or name.endswith("Decoder")
        or name.endswith("Encoder")
        or name.endswith("Layer")
        or name.endswith("Embeddings")
        or name.endswith("Attention")
    ):
        return True
    # Submodules are not documented.
    if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
        os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
    ):
        return True
    # All load functions are not documented.
    if name.startswith("load_tf") or name.startswith("load_pytorch"):
        return True
    # is_xxx_available functions are not documented.
    if name.startswith("is_") and name.endswith("_available"):
        return True
    # Deprecated objects are not documented.
    if name in DEPRECATED_OBJECTS or name in UNDOCUMENTED_OBJECTS:
        return True
    # MMBT model does not really work.
    if name.startswith("MMBT"):
        return True
    if name in SHOULD_HAVE_THEIR_OWN_PAGE:
        return True
    return False


def check_all_objects_are_documented():
    """Check all objects are properly documented."""
    documented_objs = find_all_documented_objects()
    modules = transformers._modules
    objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
    undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
    if len(undocumented_objs) > 0:
        raise Exception(
            "The following objects are in the public init so should be documented:\n - "
            + "\n - ".join(undocumented_objs)
        )
    check_docstrings_are_in_md()
    check_model_type_doc_match()


def check_model_type_doc_match():
    """Check all doc pages have a corresponding model type."""
    model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
    model_docs = [m.stem for m in model_doc_folder.glob("*.md")]

    model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
    model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]

    errors = []
    for m in model_docs:
        if m not in model_types and m != "auto":
            close_matches = get_close_matches(m, model_types)
            error_message = f"{m} is not a proper model identifier."
            if len(close_matches) > 0:
                close_matches = "/".join(close_matches)
                error_message += f" Did you mean {close_matches}?"
            errors.append(error_message)

    if len(errors) > 0:
        raise ValueError(
            "Some model doc pages do not match any existing model type:\n"
            + "\n".join(errors)
            + "\nYou can add any missing model type to the `MODEL_NAMES_MAPPING` constant in "
            "models/auto/configuration_auto.py."
        )


# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`.
_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch example introduction.
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)


def is_rst_docstring(docstring):
    """
    Returns `True` if `docstring` is written in rst.
    """
    if _re_rst_special_words.search(docstring) is not None:
        return True
    if _re_double_backquotes.search(docstring) is not None:
        return True
    if _re_rst_example.search(docstring) is not None:
        return True
    return False
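
# Illustrative rst leftovers that the three patterns above would catch in a docstring:
#
#     :obj:`torch.Tensor`        (rst role)
#     ``attention_mask``         (double backquotes)
#     Example::                  (rst example introduction)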


def check_docstrings_are_in_md():
    """Check all docstrings are in md"""
    files_with_rst = []
    for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
        with open(file, encoding="utf-8") as f:
            code = f.read()
        docstrings = code.split('"""')

        # Chunks at odd indices of the split are the ones inside triple quotes, i.e. the docstrings.
        for idx, docstring in enumerate(docstrings):
            if idx % 2 == 0 or not is_rst_docstring(docstring):
                continue
            files_with_rst.append(file)
            break

    if len(files_with_rst) > 0:
        raise ValueError(
            "The following files have docstrings written in rst:\n"
            + "\n".join([f"- {f}" for f in files_with_rst])
            + "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n"
            "(`pip install git+https://github.com/huggingface/doc-builder`)"
        )


def check_deprecated_constant_is_up_to_date():
    """Check that the `DEPRECATED_MODELS` constant is consistent with the `models/deprecated` folder."""
    deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated")
    deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]

    constant_to_check = transformers.models.auto.configuration_auto.DEPRECATED_MODELS
    message = []
    missing_models = sorted(set(deprecated_models) - set(constant_to_check))
    if len(missing_models) != 0:
        missing_models = ", ".join(missing_models)
        message.append(
            "The following models are in the deprecated folder, make sure to add them to `DEPRECATED_MODELS` in "
            f"`models/auto/configuration_auto.py`: {missing_models}."
        )

    extra_models = sorted(set(constant_to_check) - set(deprecated_models))
    if len(extra_models) != 0:
        extra_models = ", ".join(extra_models)
        message.append(
            "The following models are in the `DEPRECATED_MODELS` constant but not in the deprecated folder. Either "
            f"remove them from the constant or move them to the deprecated folder: {extra_models}."
        )

    if len(message) > 0:
        raise Exception("\n".join(message))


def check_repo_quality():
    """Check all models are properly tested and documented."""
    print("Checking all models are included.")
    check_model_list()
    print("Checking all models are public.")
    check_models_are_in_init()
    print("Checking all models are properly tested.")
    check_all_decorator_order()
    check_all_models_are_tested()
    print("Checking all objects are properly documented.")
    check_all_objects_are_documented()
    print("Checking all models are in at least one auto class.")
    check_all_models_are_auto_configured()
    print("Checking all names in auto name mappings are defined.")
    check_all_auto_object_names_being_defined()
    print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
    check_all_auto_mapping_names_in_config_mapping_names()
    print("Checking all auto mappings could be imported.")
    check_all_auto_mappings_importable()
    print("Checking all objects are equally (across frameworks) in the main __init__.")
    check_objects_being_equally_in_main_init()
    print("Checking the DEPRECATED_MODELS constant is up to date.")
    check_deprecated_constant_is_up_to_date()


if __name__ == "__main__":
    check_repo_quality()