mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add Granite Speech Support (#36801)
* First pass at speech granite; add encoder / projector, rename things
* Combine into one model file with causal lm outputs for forward
* Add loss calc
* Fix config loading
  Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
* Split new / old loading logic
* Use transformers integration for loading peft adapters
* Add generation wrapper for selective lora enablement
* Add note for qformer encoder automodel
* Guard torch/audio imports in feature extractor
* Handle granite speech autoclasses
* Handle optional deps in package structure for granite speech
* Add granite pretrained model def for init
* Add dummy objects for torch/torchaudio
* Add tests for granite speech processor
* Minor formatting fixes and refactoring
* Add options for falling back to config in forward
* Tentative model docstrings for granite speech
* Fix config type
* Remove legacy load
* Allow non-lora variants for granite speech
* Override weight tying for llm
* Use text config instead of llm config
* Add output embeddings getter to fix weight tying
* Fix relative imports
* Computing the number of audio features, based on the raw audio sequence
* Collating audio inputs, and keeping the original lengths
* Asserted we have text; otherwise we can't specify the audio special token
* Asserting the number of audio symbols/audios match correctly; running get validated_audios only when audio is present
* Indentation bugfix + supporting different feature lengths when expanding audio
* Redundant, done in _get_validated_text
* Adapting the tests: we must have text (not either audio or text); _get_num_audio_features takes a list of raw lengths, provided it instead
* Minor cleanup, remove unused import
* Add more tests for batch feature processing
* Allow setting offset in rel position embeddings
* Add config option for warning if peft is not installed w/ lora
* Port blip2 qformer code into granite speech
* Add sad test for numpy arr processing
* Allow numpy arrays / tuples in granite speech processor
* Fix config type for projector
* Pad instead of creating a zeros tensor, to keep the original dtype/device (support bfloat16); cast input_features to the model dtype (support bfloat16)
* Merge Blip2QFormerConfig to GraniteSpeechProjectorConfig
* Prevent a crash when re-saving/loading the model (line 109)
* Consider additional edge cases during preprocessing
* Consider additional edge cases during preprocessing
* Add features mask for batched inference (bugfix)
* Minor refactor, remove multiaudio processor tests
* Add set input/output embeddings for granite speech
* Fix feature dim check in processor test
* Pop input features in embed test for granite speech
* Small fixes for test edge cases; add granite speech to seq2seq causal lm mapping names
* Add small tests for granite speech model
* Fix data parallelism test
* Standardize model class names
* Fix check for copies
* Fix misaligned init check
* Skip granite speech in checkpoint check
* Use default for tie_word_embeddings in granite speech
* Fix non documentation granite speech repo issues
* Fix comments and docstring checks
* Add placeholder docs for granite speech
* Fix test naming collision
* Code formatting
* Rerun torch dummy obj regen
* Fix save pretrained for granite speech
* Import sorting
* Fix tests typo
* Remove offset hack
* Pass args through encoder config
* Remove unused prune heads from blip2
* Removing einsum; replaced with explicit multiplication (relative positional encodings) and sdpa attention
* Remove Sequential from ConformerFeedForward and ConformerConvModule + fix for sdpa attention
* Remove GraniteSpeechConformerScale
* Rename to hidden_states
* Rename conformer layers to self.layers, remove the first linear from the list to keep the list homogeneous
* Move pre-norm to the attention/feedforward blocks (avoid complex module wrapping)
* Adding pre_norm into forward
* Feature extractor refactoring to resemble how it's done in phi4multimodal
* Rename feature_extractor to audio_processor
* Bugfix: input_feature_mask fix to get the exact number of tokens
* Fix pytest decorator in processor test
* Add (disabled) integration tests for granite speech
* Fix handling of optional feature masking
* Loosen validation in processing for vLLM compatibility
* Formatting fixes
* Update init structure to mirror llama
* Make granite speech projector generic
* Update test config to reflect generic projector
* Formatting fixes
* Fix typos, add license
* Fix undefined var in input processing
* Cleanup and expose ctc encoder
* Add missing config docstrings
* Better var names, type hints, etc
* Set attn context size in init
* Add max pos emb to encoder config
* Cleanup feature extractor
* Add granite speech architecture details
* Remove granite speech qformer ref
* Add paper link, explicit calc for qkv
* Calculate padding directly in depthwise conv1d init
* Raise value error instead of asserting
* Reorder class defs (classes used at top)
* Precompute relpos distances
* Run formatting
* Pass attention distances through forward
* Apply suggestions from code review
  Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
* Add todo for using common batch feature extraction
* Rename audios/features
* Ensure chat template may be provided to processor
* Move granite speech docs to audio models
* Add todos for input proc refactoring
* Fix import order
* Guard torch import
* Use relative imports
* Require torch backend for processor in granite speech
* Add backend guards in feature extractor

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Avihu Dekel <avihu.dekel@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
parent 435f88f1db
commit 623d395aff
@@ -823,6 +823,8 @@
      title: EnCodec
    - local: model_doc/fastspeech2_conformer
      title: FastSpeech2Conformer
    - local: model_doc/granite_speech
      title: GraniteSpeech
    - local: model_doc/hubert
      title: Hubert
    - local: model_doc/mctct
docs/source/en/model_doc/granite_speech.md (new file, 68 lines)
@@ -0,0 +1,68 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Granite Speech

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The Granite Speech model is a multimodal language model, consisting of a speech encoder, speech projector, large language model, and LoRA adapter(s). More details regarding each component of the current (Granite 3.2 Speech) model architecture may be found below.

1. Speech Encoder: A [Conformer](https://arxiv.org/abs/2005.08100) encoder trained with Connectionist Temporal Classification (CTC) on character-level targets on ASR corpora. The encoder uses block-attention and self-conditioned CTC from the middle layer.

2. Speech Projector: A query transformer (q-former) operating on the outputs of the last encoder block. The encoder and projector temporally downsample the audio features before they are merged into the multimodal embeddings processed by the LLM.

3. Large Language Model: The Granite Speech model leverages Granite LLMs, which were originally proposed in [this paper](https://arxiv.org/abs/2408.13359).

4. LoRA adapter(s): The Granite Speech model contains a modality-specific LoRA adapter, which is enabled when audio features are provided and disabled otherwise.

Note that most of the aforementioned components are implemented generically to enable compatibility and potential integration with other model architectures in transformers; a rough sketch of how the pieces compose is shown below.
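The following is only an illustrative sketch of that composition using default, randomly initialized configs; it is not an official example.

```python
from transformers import GraniteSpeechConfig, GraniteSpeechForConditionalGeneration

# Default composite config: Granite text backbone, Blip-2 q-former projector,
# and the conformer CTC encoder config introduced in this PR.
config = GraniteSpeechConfig()
print(type(config.text_config).__name__)       # GraniteConfig
print(type(config.projector_config).__name__)  # Blip2QFormerConfig
print(type(config.encoder_config).__name__)    # GraniteSpeechEncoderConfig

# Randomly initialized model wiring encoder -> projector -> LLM.
model = GraniteSpeechForConditionalGeneration(config)
```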
This model was contributed by [Alexander Brooks](https://huggingface.co/abrooks9944), [Avihu Dekel](https://huggingface.co/Avihu), and [George Saon](https://huggingface.co/gsaon).

## Usage tips

- This model bundles its own LoRA adapter, which will be automatically loaded and enabled/disabled as needed during inference calls. Be sure to install [PEFT](https://github.com/huggingface/peft) to ensure the LoRA is correctly applied!

<!-- TODO (@alex-jw-brooks) Add an example here once the model compatible with the transformers implementation is released -->
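Until that example is added, here is a minimal, hypothetical sketch of what inference is expected to look like. The checkpoint id is a placeholder, and the exact prompt format and processor argument names may differ once a compatible model is released.

```python
import torch
from transformers import AutoProcessor, GraniteSpeechForConditionalGeneration

# Placeholder checkpoint id; replace with a released Granite Speech checkpoint.
model_id = "<granite-speech-checkpoint>"

processor = AutoProcessor.from_pretrained(model_id)
model = GraniteSpeechForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# 16 kHz mono waveform; one second of silence as a stand-in input here.
waveform = torch.zeros(1, 16000)

# The prompt must contain the audio placeholder token expected by the processor.
prompt = "<|audio|>can you transcribe the speech into a written format?"
inputs = processor(text=prompt, audio=waveform, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=200)
print(processor.batch_decode(output_ids, skip_special_tokens=True))
```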
## GraniteSpeechConfig

[[autodoc]] GraniteSpeechConfig

## GraniteSpeechEncoderConfig

[[autodoc]] GraniteSpeechEncoderConfig

## GraniteSpeechProcessor

[[autodoc]] GraniteSpeechProcessor

## GraniteSpeechFeatureExtractor

[[autodoc]] GraniteSpeechFeatureExtractor

## GraniteSpeechForConditionalGeneration

[[autodoc]] GraniteSpeechForConditionalGeneration
    - forward
@@ -125,6 +125,7 @@ if TYPE_CHECKING:
    from .gpt_sw3 import *
    from .gptj import *
    from .granite import *
    from .granite_speech import *
    from .granitemoe import *
    from .granitemoeshared import *
    from .grounding_dino import *
@@ -142,6 +142,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("gptj", "GPTJConfig"),
        ("gptsan-japanese", "GPTSanJapaneseConfig"),
        ("granite", "GraniteConfig"),
        ("granite_speech", "GraniteSpeechConfig"),
        ("granitemoe", "GraniteMoeConfig"),
        ("granitemoeshared", "GraniteMoeSharedConfig"),
        ("granitevision", "LlavaNextConfig"),
@@ -491,6 +492,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("gptj", "GPT-J"),
        ("gptsan-japanese", "GPTSAN-japanese"),
        ("granite", "Granite"),
        ("granite_speech", "GraniteSpeech"),
        ("granitemoe", "GraniteMoeMoe"),
        ("granitemoeshared", "GraniteMoeSharedMoe"),
        ("granitevision", "LLaVA-NeXT"),
@@ -61,6 +61,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
        ("encodec", "EncodecFeatureExtractor"),
        ("flava", "FlavaFeatureExtractor"),
        ("glpn", "GLPNFeatureExtractor"),
        ("granite_speech", "GraniteSpeechFeatureExtractor"),
        ("groupvit", "CLIPFeatureExtractor"),
        ("hubert", "Wav2Vec2FeatureExtractor"),
        ("imagegpt", "ImageGPTFeatureExtractor"),
@@ -973,6 +973,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
@@ -997,6 +998,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
        ("moonshine", "MoonshineForConditionalGeneration"),
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForSpeechToText"),
@@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("gemma3", "Gemma3Processor"),
        ("git", "GitProcessor"),
        ("got_ocr2", "GotOcr2Processor"),
        ("granite_speech", "GraniteSpeechProcessor"),
        ("grounding-dino", "GroundingDinoProcessor"),
        ("groupvit", "CLIPProcessor"),
        ("hubert", "Wav2Vec2Processor"),
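These mapping entries are what let the `Auto*` classes resolve the `granite_speech` model type. A small sketch of that resolution using only randomly initialized weights rather than a checkpoint (note that the default config instantiates a large random model):

```python
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

# CONFIG_MAPPING now knows about the "granite_speech" model type.
config = AutoConfig.for_model("granite_speech")

# The speech seq2seq mapping resolves to GraniteSpeechForConditionalGeneration.
model = AutoModelForSpeechSeq2Seq.from_config(config)
print(type(model).__name__)  # GraniteSpeechForConditionalGeneration
```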
src/transformers/models/granite_speech/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_granite_speech import *
    from .feature_extraction_granite_speech import *
    from .modeling_granite_speech import *
    from .processing_granite_speech import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class for Granite Speech."""

from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class GraniteSpeechEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GraniteSpeechCTCEncoder`]. It is used to instantiate
    a Granite Speech audio encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Granite Speech
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        input_dim (`int`, *optional*, defaults to 160):
            Dimension of the first hidden layer of the encoder.
        num_layers (`int`, *optional*, defaults to 10):
            Number of encoder blocks.
        hidden_dim (`int`, *optional*, defaults to 1024):
            The size of the intermediate layers in the conformer encoder.
        feedforward_mult (`int`, *optional*, defaults to 4):
            Multiplier for the up/down projections in the encoder's feedforward layers;
            the projections will have an intermediate dim of size `hidden_dim * feedforward_mult`.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dim_head (`int`, *optional*, defaults to 128):
            Dimension of attention heads for each attention layer in the Transformer encoder.
        output_dim (`int`, *optional*, defaults to 42):
            Intermediate dimension of the feedforward projections in the conformer
            to be added to every other encoder block's output.
        context_size (`int`, *optional*, defaults to 200):
            Context size to be used in conformer attention.
        max_pos_emb (`int`, *optional*, defaults to 512):
            Max pos embeds to be used in attention (Shaw's relative positional encoding).
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for fully connected layers in the encoder.
        conv_kernel_size (`int`, *optional*, defaults to 15):
            Kernel size to be used for 1D convolution in each conformer block.
        conv_expansion_factor (`int`, *optional*, defaults to 2):
            Intermediate dimension to be used in conformer convolutions.

    Example:

    ```python
    >>> from transformers import GraniteSpeechEncoderConfig, GraniteSpeechCTCEncoder

    >>> # Initializing a GraniteSpeechEncoderConfig
    >>> configuration = GraniteSpeechEncoderConfig()

    >>> # Initializing a GraniteSpeechCTCEncoder (with random weights)
    >>> model = GraniteSpeechCTCEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_speech_encoder"

    def __init__(
        self,
        input_dim=160,
        num_layers=10,
        hidden_dim=1024,
        feedforward_mult=4,
        num_heads=8,
        dim_head=128,
        output_dim=42,
        context_size=200,
        max_pos_emb=512,
        dropout=0.1,
        conv_kernel_size=15,
        conv_expansion_factor=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.feedforward_mult = feedforward_mult
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.output_dim = output_dim
        self.context_size = context_size
        self.dropout = dropout
        self.conv_kernel_size = conv_kernel_size
        self.conv_expansion_factor = conv_expansion_factor
        self.max_pos_emb = max_pos_emb


class GraniteSpeechConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GraniteSpeechForConditionalGeneration`]. It is used to instantiate a
    Granite Speech model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `GraniteConfig`):
            The config object or dictionary of the text backbone.
        encoder_config (`GraniteSpeechEncoderConfig`, *optional*):
            The config object or dictionary of the Granite Speech CTC Encoder.
        projector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Blip2QFormerConfig`):
            The config object or dictionary of the audio projector.
        audio_token_index (`int`, *optional*, defaults to 49155):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        has_lora_adapter (`bool`, *optional*, defaults to `True`):
            Indicates whether or not the model has a LoRA adapter that should only
            be activated when processing audio inputs.
        downsample_rate (`int`, *optional*, defaults to 5):
            Downsample rate for the audio feature extractor.
        window_size (`int`, *optional*, defaults to 15):
            Window size for the audio feature projector.

    Example:

    ```python
    >>> from transformers import GraniteSpeechConfig, GraniteSpeechForConditionalGeneration

    >>> # Initializing a GraniteSpeechConfig
    >>> configuration = GraniteSpeechConfig()

    >>> # Initializing a GraniteSpeechForConditionalGeneration (with random weights)
    >>> model = GraniteSpeechForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_speech"
    sub_configs = {
        "text_config": AutoConfig,
        "encoder_config": GraniteSpeechEncoderConfig,
        "projector_config": AutoConfig,
    }

    def __init__(
        self,
        text_config=None,
        encoder_config=None,
        projector_config=None,
        audio_token_index=49155,
        initializer_range=0.02,
        has_lora_adapter=True,
        downsample_rate=5,
        window_size=15,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "granite"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["granite"]()

        if isinstance(projector_config, dict):
            projector_config["model_type"] = (
                projector_config["model_type"] if "model_type" in projector_config else "blip_2_qformer"
            )
            projector_config = CONFIG_MAPPING[projector_config["model_type"]](**projector_config)
        elif projector_config is None:
            projector_config = CONFIG_MAPPING["blip_2_qformer"]()

        if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
            encoder_config = {} if encoder_config is None else encoder_config
            encoder_config = GraniteSpeechEncoderConfig(**encoder_config)

        self.text_config = text_config
        self.encoder_config = encoder_config
        self.projector_config = projector_config
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range
        self.has_lora_adapter = has_lora_adapter
        self.downsample_rate = downsample_rate
        self.window_size = window_size
        super().__init__(**kwargs)


__all__ = ["GraniteSpeechEncoderConfig", "GraniteSpeechConfig"]
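Because the `__init__` above coerces plain dictionaries into their config classes via `CONFIG_MAPPING`, the composite config can also be built entirely from dicts. A small sketch with arbitrarily chosen values for illustration:

```python
from transformers import GraniteSpeechConfig

# Sub-configs given as dicts are converted using their "model_type"
# (defaulting to "granite" / "blip_2_qformer" when omitted).
config = GraniteSpeechConfig(
    text_config={"model_type": "granite", "hidden_size": 512, "num_hidden_layers": 2},
    projector_config={"hidden_size": 256, "num_hidden_layers": 2},
    encoder_config={"num_layers": 2, "hidden_dim": 256},
)
print(type(config.projector_config).__name__)  # Blip2QFormerConfig
```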
@@ -0,0 +1,208 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Granite Speech."""

import math
from collections.abc import Sequence
from typing import Optional

import numpy as np

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...tokenization_utils_base import AudioInput
from ...utils import is_torch_available, is_torchaudio_available, logging
from ...utils.import_utils import requires_backends


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

if is_torchaudio_available():
    import torchaudio


class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 16000,
        n_fft: int = 512,
        win_length: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        projector_window_size: int = 15,
        projector_downsample_rate: int = 5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.melspec_kwargs = {
            "sample_rate": sampling_rate,
            "n_fft": n_fft,
            "win_length": win_length,
            "hop_length": hop_length,
            "n_mels": n_mels,
        }
        # Currently lazily initialized
        self.melspec = None
        self.projector_window_size = projector_window_size
        self.projector_downsample_rate = projector_downsample_rate

    def __call__(
        self,
        audios: AudioInput,
        device: Optional[str] = "cpu",
    ) -> BatchFeature:
        requires_backends(self, ["torchaudio"])

        speech_inputs = {}
        batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios)
        speech_inputs["input_features"] = self._extract_mel_spectrograms(
            batched_audio,
            device=device,
        )
        audio_embed_sizes = self._get_num_audio_features(audio_lengths)
        speech_inputs["audio_embed_sizes"] = audio_embed_sizes
        # TODO (@alex-jw-brooks): Currently input_features_mask is not
        # a great name, because input_features and input_features_mask
        # have different shapes (before/after the projector).
        #
        # We should align this with other multimodal models, e.g., llava
        # and qwen2audio, and refactor this to ensure input_feature_mask
        # has the same dimensionality as input_features, or compute it in
        # the model based on the audio embedding sizes (since we do not
        # have an attention mask for the audio features to infer padding from).
        speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
            audio_embed_sizes
        ).view(-1, 1)
        return BatchFeature(data=speech_inputs)

    def _ensure_melspec_transform_is_initialized(self):
        """
        Ensures the mel spectrogram transform on this instance is initialized.

        We do this for now since some logging explodes since the mel spectrogram
        transform is not JSON serializable.
        """
        requires_backends(self, ["torchaudio"])

        if self.melspec is None:
            # TODO (@alex-jw-brooks / @eustlb) move this to common batch
            # feature extraction in audio utils once they are written!
            self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)

    def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
        """
        Compute the Mel features to be passed to the conformer encoder.
        """
        requires_backends(self, ["torchaudio"])

        # Initialize the mel spectrogram if it isn't already and
        # move the melspec / audio to the computation device.
        self._ensure_melspec_transform_is_initialized()
        if device is not None:
            melspec = self.melspec.to(device)
            audio = audio.to(device)
        else:
            melspec = self.melspec

        bsz = audio.shape[0]
        with torch.no_grad():
            # Compute mel features
            mel = melspec(audio.float())
            logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
            mx = logmel.amax(dim=(-2, -1), keepdim=True)
            logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
            # remove last frame if odd
            if logmel.shape[1] % 2 == 1:
                logmel = logmel[:, :-1]

            # stacking and skipping by 2
            audio = logmel.reshape(bsz, -1, 2 * logmel.shape[-1])

        if audio.device != "cpu":
            return audio.detach().cpu()
        return audio

    def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int]:
        """
        Gets the (variable length) number of features (i.e., projector output) for the sequences
        being considered.

        Args:
            audio_lengths (`Sequence[int]`):
                Sequence of one or more raw audio lengths.
        """
        hop_length = self.melspec_kwargs["hop_length"]
        effective_window_size = self.projector_window_size // self.projector_downsample_rate

        projector_lengths = []
        for raw_length in audio_lengths:
            # mel sequence length computation
            mel_length = raw_length // hop_length + 1
            # encoder frame takes two mel features
            encoder_length = mel_length // 2
            nblocks = math.ceil(encoder_length / self.projector_window_size)
            # projector output length
            projector_length = nblocks * effective_window_size
            projector_lengths.append(projector_length)

        return projector_lengths

    def _get_audios_and_audio_lengths(self, audios: AudioInput) -> tuple["torch.Tensor", Sequence[int]]:
        """
        Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking.

        Args:
            audios (`AudioInput`):
                Audio sequence, numpy array, or torch tensor.
        """
        requires_backends(self, ["torch"])

        # Coerce to PyTorch tensors if we have numpy arrays, since
        # currently we have a dependency on torch/torchaudio anyway
        if isinstance(audios, np.ndarray):
            audios = torch.from_numpy(audios)
        elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray):
            audios = [torch.from_numpy(arr) for arr in audios]

        if isinstance(audios, torch.Tensor):
            if audios.ndim == 1:
                audios = audios.unsqueeze(0)
            if not torch.is_floating_point(audios):
                raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1")

            if audios.shape[0] > 1:
                logger.warning("Audio samples are already collated; assuming they all have the same length")
            lengths = [audios.shape[-1]] * audios.shape[0]
            return audios, lengths

        elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor):
            if not torch.is_floating_point(audios[0]):
                raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1")
            lengths = [audio.shape[-1] for audio in audios]
            padding = [max(lengths) - length for length in lengths]
            # ensure all audios have a batch dimension:
            audios = [audio.view(1, -1) for audio in audios]
            padded = [torch.nn.functional.pad(audio, (0, pad)) for audio, pad in zip(audios, padding)]
            audios = torch.cat(padded, dim=0)
            return audios, lengths

        raise TypeError("Invalid audio provided. Audio should be one or more torch tensors or numpy arrays")


__all__ = ["GraniteSpeechFeatureExtractor"]
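As a rough sketch of what this feature extractor returns (it requires torch and torchaudio; shapes follow from the mel stacking in `_extract_mel_spectrograms` and the length math in `_get_num_audio_features`, with values worked out here purely for illustration):

```python
import torch
from transformers import GraniteSpeechFeatureExtractor

feature_extractor = GraniteSpeechFeatureExtractor()

# Two mono 16 kHz clips of different lengths, as float tensors in [0, 1].
audios = [torch.rand(16000), torch.rand(24000)]
features = feature_extractor(audios, device="cpu")

# Mel frames are stacked two at a time, so the feature dim is 2 * n_mels = 160,
# and the batch is padded to the longest clip (24000 samples -> 75 stacked frames).
print(features["input_features"].shape)       # torch.Size([2, 75, 160])
# Per-clip projector output lengths, from _get_num_audio_features.
print(features["audio_embed_sizes"])          # [12, 15]
# Boolean mask over the padded projector length.
print(features["input_features_mask"].shape)  # torch.Size([2, 15])
```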
@@ -0,0 +1,673 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_peft_available,
    logging,
    replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_granite_speech import (
    GraniteSpeechConfig,
    GraniteSpeechEncoderConfig,
)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GraniteSpeechConfig"


@dataclass
class GraniteSpeechCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Granite Speech causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
### Projector
class GraniteSpeechEncoderProjector(nn.Module):
    def __init__(self, config: GraniteSpeechConfig):
        super().__init__()
        self.hidden_size = config.projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        self.query = nn.Parameter(torch.zeros(1, self.num_queries, config.projector_config.hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)

        # By default, this will be a blip_2_qformer config
        self.qformer = AutoModel.from_config(config.projector_config)
        self.linear = nn.Linear(config.projector_config.hidden_size, config.text_config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.size()
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
        hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)

        query_output = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=None,
            return_dict=True,
        )
        query_proj = self.linear(
            query_output.last_hidden_state.view(batch_size, nblocks * self.window_size // self.downsample_rate, -1)
        )
        return query_proj
### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.up_proj = nn.Linear(config.hidden_dim, config.hidden_dim * config.feedforward_mult)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(config.dropout)
        self.down_proj = nn.Linear(config.hidden_dim * config.feedforward_mult, config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        hidden_states = self.up_proj(hidden_states)
        hidden_states = self.dropout(self.silu(hidden_states))
        hidden_states = self.down_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional embeddings.

    See the following [paper](https://arxiv.org/pdf/1803.02155) for more details.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
        self.dropout = nn.Dropout(config.dropout)

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError("Context size is either less than 0 or exceeds the max_pos_emb")

    def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, self.context_size - remainder))

        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

        query_states = query_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3)
        key_states = key_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3)
        value_states = value_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3)

        # Shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale

        if remainder > 0:
            # masked attention in the extended block
            mask = torch.ones(self.context_size, self.context_size, dtype=bool, device=hidden_states.device)
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                query_states, key_states, value_states, attn_mask=pos_attn, scale=self.scale
            )
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        out = self.to_out(out[:, :num_features, :])
        return self.dropout(out)


class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for padded 1D depthwise convolution."""

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
        super().__init__()
        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)


class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module consisting of several 1D/depthwise 1D convolutional layers."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
        hidden_states = self.glu(hidden_states)
        hidden_states = self.depth_conv(hidden_states)
        hidden_states = self.silu(self.batch_norm(hidden_states))
        hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block, consisting largely of linear layers, attention, and convolutional layers."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config)
        self.attn = GraniteSpeechConformerAttention(config)
        self.conv = GraniteSpeechConformerConvModule(config)
        self.ff2 = GraniteSpeechConformerFeedForward(config)
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
        hidden_states = self.attn(hidden_states, attention_dists=attention_dists) + hidden_states
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
        hidden_states = self.post_norm(hidden_states)
        return hidden_states


class GraniteSpeechCTCEncoder(nn.Module):
    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.config = config

        # Precompute clamped relative positional encoding distances
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        self.attention_dists = torch.clamp(relpos_dist, -config.context_size, config.context_size) + config.max_pos_emb

        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        self.layers = nn.ModuleList([GraniteSpeechConformerBlock(config) for _ in range(config.num_layers)])

        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
        self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, attention_dists=self.attention_dists)

            if idx == self.num_layers // 2:
                hidden_states_mid = hidden_states.clone()
                hidden_states_mid = self.out(hidden_states_mid)
                hidden_states += self.out_mid(nn.Softmax(dim=-1)(hidden_states_mid))
        return hidden_states
GRANITE_SPEECH_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config (`GraniteSpeechConfig`):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Granite Speech Model outputting raw hidden-states without any specific head on top.",
    GRANITE_SPEECH_START_DOCSTRING,
)
class GraniteSpeechPreTrainedModel(PreTrainedModel):
    config_class = GraniteSpeechConfig
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
GRANITE_SPEECH_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_features (`torch.FloatTensor` of shape `(batch_size, audio seq len, mel feat dim)`):
            The tensors corresponding to the input audios. Input features can be obtained using
            [`AutoFeatureExtractor`]. See [`GraniteSpeechFeatureExtractor.__call__`] for details.
            [`GraniteSpeechProcessor`] uses [`GraniteSpeechFeatureExtractor`] for processing audio.
        input_features_mask (`torch.Tensor`, *optional*):
            Mask for extracted audio features that should be ignored when creating the merged
            multimodal representation (i.e., due to padding).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    """The Granite Speech model, which consists of an audio encoder, projector, and language model.""",
    GRANITE_SPEECH_START_DOCSTRING,
)
class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
    def __init__(self, config: GraniteSpeechConfig):
        super().__init__(config)
        # NOTE: It doesn't matter when we initialize from config, but we should be careful
        # to make sure this does not pick up the adapter_config if in the future we use
        # from_pretrained or something similar, since that should be set by the composite
        # model; we don't need to consider it twice.
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

        self.encoder = GraniteSpeechCTCEncoder(config.encoder_config)
        self.projector = GraniteSpeechEncoderProjector(config)

        if config.has_lora_adapter and not is_peft_available():
            logger.warning(
                "Config indicates that a lora adapter should be present, but "
                "peft is not installed; this will cause the model to perform "
                "incorrectly when audio inputs are provided. Please install "
                "peft and reload the model!"
            )

        self.post_init()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def get_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
        """Get the audio features to be merged into the multimodal embeddings."""
        encoder_embeds = self.encoder(input_features)
        projected_embeds = self.projector(encoder_embeds)
        return projected_embeds

    @add_start_docstrings_to_model_forward(GRANITE_SPEECH_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GraniteSpeechCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_features: torch.FloatTensor = None,
        input_features_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple[torch.Tensor], GraniteSpeechCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
        """
        # TODO (@alex-jw-brooks) add an example to this docstring once models are released
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_features is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            # Get the base embeddings; set all audio tokens to 0 index
            # to avoid out of vocabulary issues with the LLM embedding.
            # Audio features will be masked into is_audio_idx indices later.
            is_audio_idx = input_ids == self.config.audio_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[is_audio_idx] = 0
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if input_features is not None:
            if input_features.dtype != self.dtype:
                input_features = input_features.to(self.dtype)
            # Get the audio features from the encoder / projector
            audio_embeds = self.get_audio_features(input_features)

            # Merge the audio features into the LLM embeddings
            inputs_embeds = self.get_merged_audio_embeddings(
                input_ids=input_ids,
                audio_features=audio_embeds,
                input_features_mask=input_features_mask,
            )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )
        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return GraniteSpeechCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        input_features=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model

        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If we're in cached decoding stage, input_features should be None because
|
||||
# input ids do not contain special audio token anymore Otherwise we need
|
||||
# input feature values to be passed to the model
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["input_features"] = input_features
|
||||
return model_inputs
|
||||
|
||||
def get_merged_audio_embeddings(
|
||||
self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adds the audio token to the model's LLM vocabulary so that we can pass it
|
||||
through the tokenizer; it's assumed that the embeddings corresponding to the
|
||||
<|audio|> token will be clobbered with speech features.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.Tensor`):
|
||||
Input IDs containing one or more audio tokens.
|
||||
audio_features (`torch.Tensor`):
|
||||
Audio features to be masked into the language embeddings to form multimodal embeddings.
|
||||
input_features_mask (`torch.Tensor`, *optional*, defaults to `None`)
|
||||
Mask to be applied to audio features prior to scattering into the language embeddings.
|
||||
"""
|
||||
is_audio_index = input_ids == self.config.audio_token_index
|
||||
llm_input_ids = torch.where(is_audio_index, 0, input_ids)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
|
||||
|
||||
# Mask the audio features into the text embeddings
|
||||
special_audio_mask = is_audio_index.unsqueeze(-1)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
if input_features_mask is not None:
|
||||
            if torch.any(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)).item():
                raise ValueError("Number of audio tokens does not match number of audio features")

            audio_features = audio_features[input_features_mask]

        inputs_embeds = inputs_embeds.masked_scatter(
            special_audio_mask,
            audio_features,
        )
        return inputs_embeds

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        # This model is expected to have a lora adapter, which is only
        # enabled when considering audio inputs. As such, we override generate
        # to conditionally enable / disable the lora adapter based on whether
        # or not any input features were provided.

        input_features = kwargs.pop("input_features", None)
        if is_peft_available() and self._hf_peft_config_loaded:
            if input_features is not None:
                self.enable_adapters()
            else:
                self.disable_adapters()
        return super().generate(*args, input_features=input_features, **kwargs)

    def save_pretrained(self, *args, **kwargs):
        # overwrite save_pretrained to first save the adapter if we have one
        # NOTE - this will use the base model path we are exporting in the lora
        # adapter, which may not necessarily be the best behavior, but for now
        # we keep this for portability, since using the local dir causes problems
        # if the model is loaded from outside of the current working dir.
        if is_peft_available() and self._hf_peft_config_loaded:
            super().save_pretrained(*args, **kwargs)
            # Then save the base model afterwards
            self._hf_peft_config_loaded = False
        super().save_pretrained(*args, **kwargs)


__all__ = [
    "GraniteSpeechCTCEncoder",
    "GraniteSpeechForConditionalGeneration",
    "GraniteSpeechPreTrainedModel",
]
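A minimal end-to-end usage sketch for the class above, under stated assumptions: "ibm-granite/granite-speech" is only the placeholder path used in the tests (the public checkpoint is not yet released), and the waveform below is random noise standing in for a real recording. It shows the intended flow: the processor expands the audio token and extracts features, and generate() toggles the LoRA adapter based on whether input_features are present.

import torch

from transformers import AutoProcessor, GraniteSpeechForConditionalGeneration

model_path = "ibm-granite/granite-speech"  # placeholder path from the tests, not a released checkpoint
processor = AutoProcessor.from_pretrained(model_path)
model = GraniteSpeechForConditionalGeneration.from_pretrained(model_path)

# One <|audio|> token per clip; the processor expands it to match the projected feature count.
text = f"{processor.audio_token}can you transcribe the speech into a written format?"
wav = torch.rand([1, 269920]) - 0.5  # stand-in waveform (a real recording would be used in practice)

inputs = processor(text, wav, return_tensors="pt").to(model.device)

# generate() enables the LoRA adapter (when one is loaded) because input_features are present,
# and disables it again for text-only calls.
output = model.generate(**inputs, max_new_tokens=32)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))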
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Granite Speech."""

from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import is_torch_available, logging
from ...utils.import_utils import requires_backends


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class GraniteSpeechProcessor(ProcessorMixin):
    attributes = ["audio_processor", "tokenizer"]
    valid_kwargs = ["audio_token"]

    audio_processor_class = "GraniteSpeechFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        audio_processor,
        tokenizer,
        audio_token="<|audio|>",
        chat_template=None,
    ):
        self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
        super().__init__(audio_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        audio: Union["torch.Tensor", List["torch.Tensor"]] = None,
        device: str = "cpu",
        images=None,
        videos=None,
        **kwargs,
    ) -> BatchFeature:
        requires_backends(self, ["torch"])

        text = self._get_validated_text(text)
        prompt_strings = text

        if audio is not None:
            # NOTE - we intentionally avoid throwing for potentially misaligned
            # text / audio inputs here because some inference engines will
            # trigger the conditions due to the way they call multimodal
            # processors, e.g., vLLM.
            audio_inputs = self.audio_processor(audio, device=device)

            # TODO (@alex-jw-brooks): we should add a util to get_num_audio_tokens
            # from feature lengths and call it here, rather than returning it
            # from the feature extractor.
            audio_embed_sizes = audio_inputs.pop("audio_embed_sizes")

            # Expand the audio placeholders to match the feature dims; this
            # is similar to how many VLMs handle image tokens, e.g., llava next
            prompt_strings = []
            num_replaced = 0
            for sample in text:
                while self.audio_token in sample:
                    sample = sample.replace(
                        self.audio_token,
                        "<placeholder>" * audio_embed_sizes[num_replaced],
                        1,
                    )
                    num_replaced += 1
                prompt_strings.append(sample)

            prompt_strings = [sample.replace("<placeholder>", self.audio_token) for sample in prompt_strings]
        else:
            audio_inputs = {}

        text_inputs = self.tokenizer(prompt_strings, padding=True, **kwargs)
        return BatchFeature(data={**text_inputs, **audio_inputs})

    def _get_validated_text(self, text: Union[str, list]) -> List[str]:
        if isinstance(text, str):
            return [text]
        elif isinstance(text, list) and isinstance(text[0], str):
            return text
        raise TypeError("Invalid text provided! Text should be a string or list of strings.")


__all__ = ["GraniteSpeechProcessor"]
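A standalone sketch of the placeholder-expansion step performed in __call__ above: each <|audio|> token is replaced by as many audio tokens as the corresponding clip contributes projected features. The embed sizes below are made-up numbers for illustration; in the processor they come from the feature extractor.

audio_token = "<|audio|>"
audio_embed_sizes = [3, 2]  # hypothetical per-clip feature counts
texts = [f"{audio_token} describe this audio", f"{audio_token} and compare it with this one"]

prompt_strings = []
num_replaced = 0
for sample in texts:
    while audio_token in sample:
        # expand via an intermediate placeholder so already-expanded tokens are not re-matched
        sample = sample.replace(audio_token, "<placeholder>" * audio_embed_sizes[num_replaced], 1)
        num_replaced += 1
    prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", audio_token) for sample in prompt_strings]

print(prompt_strings[0])  # <|audio|><|audio|><|audio|> describe this audio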
@@ -2,6 +2,20 @@
from ..utils import DummyObject, requires_backends


class GraniteSpeechFeatureExtractor(metaclass=DummyObject):
    _backends = ["torchaudio"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchaudio"])


class GraniteSpeechProcessor(metaclass=DummyObject):
    _backends = ["torchaudio"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchaudio"])


class MusicgenMelodyFeatureExtractor(metaclass=DummyObject):
    _backends = ["torchaudio"]
@@ -1659,6 +1659,12 @@ class GenerationTesterMixin:
        inputs_dict.pop("pixel_values", None)
        inputs_dict.pop("pixel_values_videos", None)
        inputs_dict.pop("pixel_values_images", None)
        # HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive;
        # this is similar to VLMs and should likely be standardized for similar audio models in the future,
        # then made generic here.
        if "granitespeech" in model_class.__name__.lower():
            inputs_dict.pop("input_features", None)

        # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
        has_complex_embeds_computation = any(
            model_name in model_class.__name__.lower() for model_name in ["moshi"]
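For context on the hack above, a small self-contained sketch of the validation that makes input_features and inputs_embeds mutually exclusive in GraniteSpeechForConditionalGeneration.forward; the helper name below is local to this example, not part of the library.

def validate_multimodal_inputs(input_ids=None, inputs_embeds=None, input_features=None):
    # Mirrors the checks at the top of forward(): exactly one of input_ids / inputs_embeds,
    # and audio features can only be merged when starting from input_ids.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if input_features is not None and inputs_embeds is not None:
        raise ValueError("input_features cannot be combined with inputs_embeds")


validate_multimodal_inputs(input_ids=[1, 2, 3], input_features="features")  # ok: ids + audio
validate_multimodal_inputs(inputs_embeds="embeds")                          # ok: embeds only
# validate_multimodal_inputs(inputs_embeds="embeds", input_features="features")  # raises ValueError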
0    tests/models/granite_speech/__init__.py    Normal file
393  tests/models/granite_speech/test_modeling_granite_speech.py    Normal file
@@ -0,0 +1,393 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the IBM Granite Speech model."""

import tempfile
import unittest

import pytest

from transformers import (
    AutoProcessor,
    GraniteSpeechConfig,
    GraniteSpeechForConditionalGeneration,
)
from transformers.testing_utils import (
    cleanup,
    require_torch,
    require_torch_sdpa,
    slow,
    torch_device,
)
from transformers.utils import (
    is_datasets_available,
    is_torch_available,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
)


if is_torch_available():
    import torch

if is_datasets_available():
    from datasets import load_dataset


class GraniteSpeechForConditionalGenerationModelTester:
    def __init__(
        self,
        parent,
        seq_length=7,
        encoder_config={
            "model_type": "granite_speech_encoder",
            "context_size": 200,
            "conv_expansion_factor": 2,
            "conv_kernel_size": 15,
            "dim_head": 32,
            "dropout": 0.1,
            "feedforward_mult": 4,
            "hidden_dim": 32,
            "input_dim": 160,
            "num_heads": 4,
            "num_layers": 2,
            "output_dim": 42,
        },
        text_config={
            "model_type": "granite",
            "is_training": True,
            "seq_length": 7,
            "use_token_type_ids": False,
            "use_labels": True,
            "vocab_size": 99,
            "hidden_size": 32,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 580,
            "type_vocab_size": 16,
            "type_sequence_label_size": 2,
            "initializer_range": 0.02,
            "num_labels": 3,
            "num_choices": 4,
            "pad_token_id": 1,
        },
        projector_config={
            "attention_probs_dropout_prob": 0.1,
            "cross_attention_frequency": 1,
            "encoder_hidden_size": 32,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 32,
            "initializer_range": 0.02,
            "intermediate_size": 256,
            "layer_norm_eps": 1e-12,
            "max_position_embeddings": 2048,
            "model_type": "blip_2_qformer",
            "num_attention_heads": 4,
            "num_hidden_layers": 2,
            "position_embedding_type": "absolute",
            "use_qformer_text_input": False,
            "vocab_size": 30522,
        },
        audio_token_index=0,
        tie_word_embeddings=True,
        initializer_range=0.02,
        has_lora_adapter=True,
        downsample_rate=5,
        window_size=15,
        is_training=True,
    ):
        self.parent = parent
        self.encoder_config = encoder_config
        self.text_config = text_config
        self.projector_config = projector_config
        self.audio_token_index = audio_token_index
        self.tie_word_embeddings = tie_word_embeddings
        self.initializer_range = initializer_range
        self.has_lora_adapter = has_lora_adapter
        self.downsample_rate = downsample_rate
        self.window_size = window_size
        self.is_training = is_training

        # Dims for audio features
        self.sequence_dim = 844
        self.feature_dim = 160
        self.num_attention_heads = text_config["num_attention_heads"]
        self.num_hidden_layers = text_config["num_hidden_layers"]
        self.hidden_size = text_config["hidden_size"]
        self.batch_size = 3
        self.pad_token_id = text_config["pad_token_id"]
        self.seq_len = 7
        self.num_audio_tokens = 2
        self.seq_length = seq_length + self.num_audio_tokens

    def get_config(self):
        return GraniteSpeechConfig(
            encoder_config=self.encoder_config,
            text_config=self.text_config,
            projector_config=self.projector_config,
            audio_token_index=self.audio_token_index,
            tie_word_embeddings=self.tie_word_embeddings,
            initializer_range=self.initializer_range,
            has_lora_adapter=self.has_lora_adapter,
        )

    def prepare_config_and_inputs(self):
        input_features = floats_tensor(
            [self.batch_size, self.sequence_dim, self.feature_dim],
        )
        config = self.get_config()
        return config, input_features

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_features = config_and_inputs
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
        input_ids[input_ids == config.audio_token_index] = self.pad_token_id

        input_ids[:, : self.num_audio_tokens] = config.audio_token_index

        inputs_dict = {
            "input_features": input_features,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def create_and_check_granite_speech_model_fp16_forward(self, config, input_ids, input_features, attention_mask):
        model = GraniteSpeechForConditionalGeneration(config=config)
        model.to(torch_device)
        model.half()
        model.eval()
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            return_dict=True,
        )["logits"]
        self.parent.assertFalse(torch.isnan(logits).any().item())

    def create_and_check_granite_speech_model_fp16_autocast_forward(
        self,
        config,
        input_ids,
        input_features,
        attention_mask,
    ):
        config.torch_dtype = torch.float16
        model = GraniteSpeechForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                input_features=input_features.to(torch.bfloat16),
                return_dict=True,
            )["logits"]
        self.parent.assertFalse(torch.isnan(logits).any().item())


@require_torch
class GraniteSpeechForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """
    Model tester for `GraniteSpeechForConditionalGeneration`.
    """

    all_model_classes = (GraniteSpeechForConditionalGeneration,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    _is_composite = True

    def setUp(self):
        self.model_tester = GraniteSpeechForConditionalGenerationModelTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=GraniteSpeechConfig,
            has_text_modality=False,
        )

    def test_inputs_embeds(self):
        # overwrite the inputs_embeds tests because we need to delete "input features" for the audio model
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["input_features"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    def test_initialization(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if name == "projector.query":
                    continue
                elif param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        # overwrite because Granite Speech is an audio+text model (not vision+text)
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if not self._is_composite:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        for model_class in self.all_model_classes:
            # NOTE - currently we only enable alternate attention implementations on
            # the encapsulated LLM; in the future, this should be added for the conformer
            # encoder as well.
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(tmpdirname)
                model_sdpa = model_sdpa.eval().to(torch_device)

                text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"

                # `None` as it is the requested one which will be assigned to each sub-config
                # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
                self.assertTrue(model.language_model.config._attn_implementation == text_attn)

                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
                model_eager = model_eager.eval().to(torch_device)
                self.assertTrue(model_eager.config._attn_implementation == "eager")
                self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")

                for name, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")


class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
    def setUp(self):
        # TODO - use the actual model path on HF hub after release.
        self.model_path = "ibm-granite/granite-speech"
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        self.prompt = self._get_prompt(self.processor.tokenizer)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    def _get_prompt(self, tokenizer):
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            },
        ]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    @slow
    @pytest.mark.skip("Public models not yet available")
    def test_small_model_integration_test_single(self):
        model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
        input_speech = self._load_datasamples(1)

        # Verify feature sizes; note that the feature mask refers to the size of
        # features that are masked into the LLM, not the output of the processor,
        # which is why we inspect the mask instead of the `num_features` tensor.
        inputs = self.processor(self.prompt, input_speech, return_tensors="pt").to(torch_device)

        num_computed_features = self.processor.audio_processor._get_num_audio_features(
            [speech_arr.shape[-1] for speech_arr in input_speech],
        )[0]
        num_actual_features = torch.sum(inputs["input_features_mask"]).item()
        assert num_actual_features == num_computed_features

        # verify generation
        output = model.generate(**inputs, max_new_tokens=32)
        EXPECTED_DECODED_TEXT = "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel"  # fmt: skip

        self.assertEqual(
            self.processor.tokenizer.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @pytest.mark.skip("Public models not yet available")
    def test_small_model_integration_test_batch(self):
        model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path)
        input_speech = self._load_datasamples(2)
        prompts = [self.prompt, self.prompt]

        # Verify feature sizes & padding
        inputs = self.processor(prompts, input_speech, return_tensors="pt").to(model.device)
        num_computed_features = self.processor.audio_processor._get_num_audio_features(
            [speech_arr.shape[-1] for speech_arr in input_speech],
        )
        num_actual_features = torch.sum(inputs["input_features_mask"], dim=-1)
        for e_feats, a_feats in zip(num_computed_features, num_actual_features):
            assert e_feats == a_feats.item()

        # verify generation
        output = model.generate(**inputs, max_new_tokens=32)

        EXPECTED_DECODED_TEXT = [
            "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter",
        ]  # fmt: skip

        self.assertEqual(
            self.processor.tokenizer.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )
222  tests/models/granite_speech/test_processor_granite_speech.py    Normal file
@@ -0,0 +1,222 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest

import numpy as np
import pytest
import torch
from parameterized import parameterized

from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers.testing_utils import (
    require_torch,
    require_torch_gpu,
    require_torchaudio,
)
from transformers.utils import is_torchaudio_available


if is_torchaudio_available():
    from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor


@pytest.skip("Public models not yet available", allow_module_level=True)
@require_torch
@require_torchaudio
class GraniteSpeechProcessorTest(unittest.TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        # TODO - use the actual model path on HF hub after release.
        self.checkpoint = "ibm-granite/granite-speech"
        processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoTokenizer.from_pretrained(self.checkpoint, **kwargs)

    def get_audio_processor(self, **kwargs):
        return GraniteSpeechFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        """Ensure we can save / reload a processor correctly."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        processor.save_pretrained(self.tmpdirname)
        processor = GraniteSpeechProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, GPT2TokenizerFast)

        self.assertEqual(processor.audio_processor.to_json_string(), audio_processor.to_json_string())
        self.assertIsInstance(processor.audio_processor, GraniteSpeechFeatureExtractor)

    def test_requires_text(self):
        """Ensure we require text."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        with pytest.raises(TypeError):
            processor(text=None)

    def test_bad_text_fails(self):
        """Ensure we gracefully fail if text is the wrong type."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()

        processor = GraniteSpeechProcessor(tokenizer=tokenizer, audio_processor=audio_processor)
        with pytest.raises(TypeError):
            processor(text=424, audio=None)

    def test_bad_nested_text_fails(self):
        """Ensure we gracefully fail if text is the wrong nested type."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        with pytest.raises(TypeError):
            processor(text=[424], audio=None)

    def test_bad_audio_fails(self):
        """Ensure we gracefully fail if audio is the wrong type."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        with pytest.raises(TypeError):
            processor(text=None, audio="foo")

    def test_nested_bad_audio_fails(self):
        """Ensure we gracefully fail if audio is the wrong nested type."""
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        with pytest.raises(TypeError):
            processor(text=None, audio=["foo"])

    @parameterized.expand(
        [
            ([1, 269920], [171], torch.rand),
            ([1, 269920], [171], np.random.rand),
        ]
    )
    def test_audio_token_filling_same_len_feature_tensors(self, vec_dims, num_expected_features, random_func):
        """Ensure audio token filling is handled correctly when we have
        one or more audio inputs whose features are all the same length
        stacked into a tensor / numpy array.

        NOTE: Currently we enforce that each sample can only have one audio.
        """
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )
        audio = random_func(*vec_dims) - 0.5

        audio_tokens = processor.audio_token * vec_dims[0]
        inputs = processor(text=f"{audio_tokens} Can you compare this audio?", audio=audio, return_tensors="pt")

        # Check the number of audio tokens
        audio_token_id = tokenizer.get_vocab()[processor.audio_token]

        # Make sure the number of audio tokens matches the number of features
        num_computed_features = processor.audio_processor._get_num_audio_features(
            [vec_dims[1] for _ in range(vec_dims[0])],
        )
        num_audio_tokens = int(torch.sum(inputs["input_ids"] == audio_token_id))
        assert list(inputs["input_features"].shape) == [vec_dims[0], 844, 160]
        assert sum(num_computed_features) == num_audio_tokens

    def test_audio_token_filling_varying_len_feature_list(self):
        """Ensure audio token filling is handled correctly when we have
        multiple varying len audio sequences passed as a list.
        """
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )
        vec_dims = [[1, 142100], [1, 269920]]
        num_expected_features = [90, 171]
        audio = [torch.rand(dims) - 0.5 for dims in vec_dims]

        inputs = processor(
            text=[
                f"{processor.audio_token} Can you describe this audio?",
                f"{processor.audio_token} How does it compare with this audio?",
            ],
            audio=audio,
            return_tensors="pt",
        )

        # Check the number of audio tokens
        audio_token_id = tokenizer.get_vocab()[processor.audio_token]

        # Make sure the number of audio tokens matches the number of features
        num_calculated_features = processor.audio_processor._get_num_audio_features(
            [dims[1] for dims in vec_dims],
        )
        num_audio_tokens = int(torch.sum(inputs["input_ids"] == audio_token_id))
        assert num_calculated_features == [90, 171]
        assert sum(num_expected_features) == num_audio_tokens

    @require_torch_gpu
    def test_device_override(self):
        """Ensure that, regardless of the processing device, the tensors
        produced are on the CPU.
        """
        tokenizer = self.get_tokenizer()
        audio_processor = self.get_audio_processor()
        processor = GraniteSpeechProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
        )

        vec_dims = [1, 269920]
        wav = torch.rand(vec_dims) - 0.5

        inputs = processor(
            text=f"{processor.audio_token} Can you transcribe this audio?",
            audio=wav,
            return_tensors="pt",
            device="cuda",
        )

        assert inputs["input_features"].device.type == "cpu"
@@ -48,6 +48,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "GraniteConfig",
    "GraniteMoeConfig",
    "Qwen3MoeConfig",
    "GraniteSpeechConfig",
}