transformers/utils/check_config_docstrings.py
Alex Brooks 623d395aff
Add Granite Speech Support (#36801)
* First pass at speech granite

Add encoder / projector, rename things

* Combine into one model file with causal lm outputs for forward

* Add loss calc

* Fix config loading

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>

* Split new / old loading logic

* Use transformers integration for loading peft adapters

* Add generation wrapper for selective lora enablement

* Add note for qformer encoder automodel

* Guard torch/audio imports in feature extractor

* Handle granite speech autoclasses

* Handle optional deps in package structure for granite speech

* Add granite pretrained model def for init

* Add dummy objects for torch/torchaudio

* Add tests for granite speech processor

* Minor formatting fixes and refactoring

* Add options for falling back to config in forward

* Tentative model docstrings for granite speech

* Fix config type

* Remove legacy load

* Allow non-lora variants for granite speech

* Override weight tying for llm

* Use text config instead of llm config

* Add output embeddings getter to fix weight tying

* Fix relative imports

* computing the number of audio features, based on the raw audio sequence.

* collating audio inputs, and keeping the original lengths.

* asserted we have text. otherwise we can't specify the audio special token.

* asserting the number of audio-symbols/audios match correctly.
running get validated_audios only when audio is present

* indentation bugfix + supporting different feature lengths when expanding audio.

* redundant, done in _get_validated_text

* adapting the tests:
- we must have text (not either audio or text)
- _get_num_audio_features takes a list of raw lengths, provided it instead.

* Minor cleanup, remove unused import

* Add more tests for batch feature processing

* Allow setting offset in rel position embeddings

* Add config option for warning if peft is not installed w/ lora

* Port blip2 qformer code into granite speech

* Add sad test for numpy arr processing

* Allow numpy arrays / tuples in granite speech processor

* Fix config type for projector

* - pad instead of creating a zeros tensor, to keep the original dtype/device (support bfloat16)
- cast input_features to the model dtype (support bfloat16)

* merge Blip2QFormerConfig to GraniteSpeechProjectorConfig

* prevent a crash when re-saving/loading the model (line 109)

* consider additional edge cases during preprocessing.

* consider additional edge cases during preprocessing.

* add features mask for batched inference (bugfix)

* Minor refactor, remove multiaudio processor tests

* Add set input/output embeddings for granite speech

* Fix feature dim check in processor test

* Pop input features in embed test for granite speech

* Small fixes for test edge cases

Add granite speech to seq2seq causal lm mapping names

* Add small tests for granite speech model

* Fix data parallelism test

* Standardize model class names

* Fix check for copies

* Fix misaligned init check

* Skip granite speech in checkpoint check

* Use default for tie_word_embeddings in granite speech

* Fix non documentation granite speech repo issues

* Fix comments and docstring checks

* Add placeholder docs for granite speech

* Fix test naming collision

* Code formatting

* Rerun torch dummy obj regen

* Fix save pretrained for granite speech

* Import sorting

* Fix tests typo

* Remove offset hack

* Pass args through encoder config

* Remove unused prune heads from blip2

* removing einsum. replaced with explicit multiplication (relative positional encodings) and sdpa attention.

* remove Sequential from ConformerFeedForward and ConformerConvModule. + fix for sdpa attention

* remove GraniteSpeechConformerScale

* rename to hidden_states

* rename conformer layers to self.layers, remove the first linear from the list to keep the list homogenous.

* move pre-norm to the attention/feedforward blocks (avoid complex module wrapping)

* adding pre_norm into forward

* feature extractor refactoring to resemble how it's done in phi4multimodal.

* rename feature_extractor to audio_processor

* bugfix: input_feature_mask fix to get the exact number tokens.

* Fix pytest decorator in processor test

* Add (disabled) integration tests for granite speech

* Fix handling of optional feature masking

* Loosen validation in processing for vLLM compatibility

* Formatting fixes

* Update init structure to mirror llama

* Make granite speech projector generic

* Update test config to reflect generic projector

* Formatting fixes

* Fix typos, add license

* Fix undefined var in input processing

* Cleanup and expose ctc encoder

* Add missing config docstrings

* Better var names, type hints, etc

* Set attn context size in init

* Add max pos emb to encoder config

* Cleanup feature extractor

* Add granite speech architecture details

* Remove granite speech qformer ref

* Add paper link, explicit calc for qkv

* Calculate padding directly in depthwise conv1d init

* Raise value error instead of asserting

* Reorder class defs (classes used at top)

* Precompute relpos distances

* Run formatting

* Pass attention distances through forward

* Apply suggestions from code review

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* Add todo for using common batch feature extraction

* Rename audios/features

* Ensure chat template may be provided to processor

* Move granite speech docs to audio models

* Add todos for input proc refactoring

* Fix import order

* Guard torch import

* Use relative imports

* Require torch backend for processor in granite speech

* Add backend guards in feature extractor

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Avihu Dekel <avihu.dekel@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
2025-04-11 18:52:00 +02:00


# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re

from transformers.utils import direct_transformers_import


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_config_docstrings.py
PATH_TO_TRANSFORMERS = "src/transformers"


# This is to make sure the transformers module imported is the one in the repo.
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)

CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING

# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
# For example, `[google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)`
_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")


CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "DecisionTransformerConfig",
    "EncoderDecoderConfig",
    "MusicgenConfig",
    "RagConfig",
    "SpeechEncoderDecoderConfig",
    "TimmBackboneConfig",
    "TimmWrapperConfig",
    "VisionEncoderDecoderConfig",
    "VisionTextDualEncoderConfig",
    "LlamaConfig",
    "GraniteConfig",
    "GraniteMoeConfig",
    "Qwen3MoeConfig",
    "GraniteSpeechConfig",
}


def get_checkpoint_from_config_class(config_class):
    checkpoint = None

    # source code of `config_class`
    config_source = inspect.getsource(config_class)
    checkpoints = _re_checkpoint.findall(config_source)

    # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
    # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
    for ckpt_name, ckpt_link in checkpoints:
        # allow the link to end with `/`
        if ckpt_link.endswith("/"):
            ckpt_link = ckpt_link[:-1]

        # verify the checkpoint name corresponds to the checkpoint link
        ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
        if ckpt_link == ckpt_link_from_name:
            checkpoint = ckpt_name
            break

    return checkpoint


def check_config_docstrings_have_checkpoints():
    configs_without_checkpoint = []

    for config_class in list(CONFIG_MAPPING.values()):
        # Skip deprecated models
        if "models.deprecated" in config_class.__module__:
            continue
        checkpoint = get_checkpoint_from_config_class(config_class)

        name = config_class.__name__
        # Report only configs that lack a valid checkpoint and are not explicitly exempted
        if checkpoint is None and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
            configs_without_checkpoint.append(name)

    if len(configs_without_checkpoint) > 0:
        message = "\n".join(sorted(configs_without_checkpoint))
        raise ValueError(
            f"The following configurations don't contain any valid checkpoint:\n{message}\n\n"
            "The requirement is to include a link pointing to one of the models of this architecture in the "
            "docstring of the config classes listed above. The link should be in markdown format, like "
            "[myorg/mymodel](https://huggingface.co/myorg/mymodel)."
        )


if __name__ == "__main__":
    check_config_docstrings_have_checkpoints()
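

The name/link consistency rule used by `get_checkpoint_from_config_class` can be tried in isolation, without importing the repo. The following is a minimal sketch, not part of the script: the helper name `first_valid_checkpoint` and the sample docstring fragments are hypothetical, and only the regex is copied from the script above.

```python
import re

# Same pattern as `_re_checkpoint` in the script above.
_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")


def first_valid_checkpoint(text):
    """Hypothetical helper: return the first checkpoint name whose link matches it, else None."""
    for ckpt_name, ckpt_link in _re_checkpoint.findall(text):
        # A single trailing `/` on the link is tolerated.
        if ckpt_link.endswith("/"):
            ckpt_link = ckpt_link[:-1]
        # The link must be exactly the Hub URL built from the name.
        if ckpt_link == f"https://huggingface.co/{ckpt_name}":
            return ckpt_name
    return None


# Illustrative docstring fragments (made up for this sketch).
good = "e.g. [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased/)"
bad = "see [my model](https://huggingface.co/myorg/mymodel)"
mixed = bad + " and [myorg/mymodel](https://huggingface.co/myorg/mymodel)"

print(first_valid_checkpoint(good))
print(first_valid_checkpoint(bad))
print(first_valid_checkpoint(mixed))
```

A link whose display text differs from the repo id (the `bad` fragment) is skipped rather than rejected outright, so a docstring passes as long as at least one link in it is consistent.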