mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Moonshine (#34784)
* config draft * full encoder forward * full decoder forward * fix sdpa and FA2 * fix sdpa and FA2 * moonshine model * moonshine model forward * fix attention with past_key_values * add MoonshineForConditionalGeneration * fix cache handling and causality for cross attention * no causal attention mask for the encoder * model addition (imports etc) * small nit * nits * Update src/transformers/models/moonshine/convert_usefulsensors_to_hf.py Co-authored-by: Joshua Lochner <admin@xenova.com> * add rope_theta * nits * model doc * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: Joshua Lochner <admin@xenova.com> * imports * add MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES * updates modular * make * make fix-copies * ruff check examples fix * fix check_modular_conversion * nit * nits * nits * copied from -> imports * imports fix * integrate attention refacto * modular edge case * remove encoder * convolutions params in config * run modular_model_converter * make * Update docs/source/en/model_doc/moonshine.md Co-authored-by: Joshua Lochner <admin@xenova.com> * MoonshineModelTest * correct typo * make style * integration tests * make * modular convert * name conversion update (up_proj -> fc1 etc) * update config * update MLP * update attention * update encoder layer * update decoder layer * update convolutions parameters * update encoder * remove INPUTS_DOCSTRING * update decoder * update conditional generation * update pretrained model * imports * modular converted * update doc * fix * typo * update doc * update license * update init * split config in file * two classes for MLP * attention from GLM * from GlmRotaryEmbedding * split MLP * apply arthur's review suggestions * apply arthur's review suggestions * apply arthur's review suggestions * auto feature extractor * convert modular * fix + make * convert modular * make * unsplit config * use correct checkpoint * wrap generate * update tests * typos * make * typo * update doc --------- Co-authored-by: Joshua Lochner <admin@xenova.com>
This commit is contained in:
parent
6f127d3f81
commit
5f087d1335
@ -505,7 +505,9 @@
|
||||
- local: model_doc/mobilebert
|
||||
title: MobileBERT
|
||||
- local: model_doc/modernbert
|
||||
title: ModernBERT
|
||||
title: ModernBert
|
||||
- local: model_doc/moonshine
|
||||
title: moonshine
|
||||
- local: model_doc/mpnet
|
||||
title: MPNet
|
||||
- local: model_doc/mpt
|
||||
|
@ -235,6 +235,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
|
||||
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
|
||||
| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ |
|
||||
| [Moonshine](model_doc/moonshine) | ✅ | ❌ | ❌ |
|
||||
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
|
||||
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
|
||||
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |
|
||||
|
56
docs/source/en/model_doc/moonshine.md
Normal file
56
docs/source/en/model_doc/moonshine.md
Normal file
@ -0,0 +1,56 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Moonshine
|
||||
|
||||
## Overview
|
||||
|
||||
The Moonshine model was proposed in [Moonshine: Speech Recognition for Live Transcription and Voice Commands
|
||||
](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*This paper introduces Moonshine, a family of speech recognition models optimized for live transcription and voice command processing. Moonshine is based on an encoder-decoder transformer architecture and employs Rotary Position Embedding (RoPE) instead of traditional absolute position embeddings. The model is trained on speech segments of various lengths, but without using zero-padding, leading to greater efficiency for the encoder during inference time. When benchmarked against OpenAI's Whisper tiny-en, Moonshine Tiny demonstrates a 5x reduction in compute requirements for transcribing a 10-second speech segment while incurring no increase in word error rates across standard evaluation datasets. These results highlight Moonshine's potential for real-time and resource-constrained applications.*
|
||||
|
||||
Tips:
|
||||
|
||||
- Moonshine improves upon Whisper's architecture:
|
||||
1. It uses SwiGLU activation instead of GELU in the decoder layers
|
||||
2. Most importantly, it replaces absolute position embeddings with Rotary Position Embeddings (RoPE). This allows Moonshine to handle audio inputs of any length, unlike Whisper which is restricted to fixed 30-second windows.
|
||||
|
||||
This model was contributed by [Eustache Le Bihan (eustlb)](https://huggingface.co/eustlb).
|
||||
The original code can be found [here](https://github.com/usefulsensors/moonshine).
|
||||
|
||||
## Resources
|
||||
|
||||
- [Automatic speech recognition task guide](../tasks/asr)
|
||||
|
||||
## MoonshineConfig
|
||||
|
||||
[[autodoc]] MoonshineConfig
|
||||
|
||||
## MoonshineModel
|
||||
|
||||
[[autodoc]] MoonshineModel
|
||||
- forward
|
||||
- _mask_input_features
|
||||
|
||||
## MoonshineForConditionalGeneration
|
||||
|
||||
[[autodoc]] MoonshineForConditionalGeneration
|
||||
- forward
|
||||
- generate
|
||||
|
@ -68,6 +68,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
|
||||
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
|
||||
@ -265,6 +266,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
|
||||
@ -283,8 +285,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
||||
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
|
||||
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||
|
@ -610,6 +610,7 @@ _import_structure = {
|
||||
"models.mobilevit": ["MobileViTConfig"],
|
||||
"models.mobilevitv2": ["MobileViTV2Config"],
|
||||
"models.modernbert": ["ModernBertConfig"],
|
||||
"models.moonshine": ["MoonshineConfig"],
|
||||
"models.moshi": [
|
||||
"MoshiConfig",
|
||||
"MoshiDepthConfig",
|
||||
@ -2907,6 +2908,13 @@ else:
|
||||
"ModernBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moonshine"].extend(
|
||||
[
|
||||
"MoonshineForConditionalGeneration",
|
||||
"MoonshineModel",
|
||||
"MoonshinePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moshi"].extend(
|
||||
[
|
||||
"MoshiForCausalLM",
|
||||
@ -5633,6 +5641,7 @@ if TYPE_CHECKING:
|
||||
MobileViTV2Config,
|
||||
)
|
||||
from .models.modernbert import ModernBertConfig
|
||||
from .models.moonshine import MoonshineConfig
|
||||
from .models.moshi import (
|
||||
MoshiConfig,
|
||||
MoshiDepthConfig,
|
||||
@ -7652,6 +7661,11 @@ if TYPE_CHECKING:
|
||||
ModernBertModel,
|
||||
ModernBertPreTrainedModel,
|
||||
)
|
||||
from .models.moonshine import (
|
||||
MoonshineForConditionalGeneration,
|
||||
MoonshineModel,
|
||||
MoonshinePreTrainedModel,
|
||||
)
|
||||
from .models.moshi import (
|
||||
MoshiForCausalLM,
|
||||
MoshiForConditionalGeneration,
|
||||
|
@ -170,6 +170,7 @@ from . import (
|
||||
mobilevit,
|
||||
mobilevitv2,
|
||||
modernbert,
|
||||
moonshine,
|
||||
moshi,
|
||||
mpnet,
|
||||
mpt,
|
||||
|
@ -190,6 +190,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTConfig"),
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("modernbert", "ModernBertConfig"),
|
||||
("moonshine", "MoonshineConfig"),
|
||||
("moshi", "MoshiConfig"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mpt", "MptConfig"),
|
||||
@ -519,6 +520,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilevit", "MobileViT"),
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("modernbert", "ModernBERT"),
|
||||
("moonshine", "Moonshine"),
|
||||
("moshi", "Moshi"),
|
||||
("mpnet", "MPNet"),
|
||||
("mpt", "MPT"),
|
||||
|
@ -73,6 +73,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
|
||||
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
|
||||
("mobilevit", "MobileViTFeatureExtractor"),
|
||||
("moonshine", "Wav2Vec2FeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("nat", "ViTFeatureExtractor"),
|
||||
("owlvit", "OwlViTFeatureExtractor"),
|
||||
|
@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTModel"),
|
||||
("mobilevitv2", "MobileViTV2Model"),
|
||||
("modernbert", "ModernBertModel"),
|
||||
("moonshine", "MoonshineModel"),
|
||||
("moshi", "MoshiModel"),
|
||||
("mpnet", "MPNetModel"),
|
||||
("mpt", "MptModel"),
|
||||
@ -436,6 +437,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
("mobilebert", "MobileBertForMaskedLM"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mpt", "MptForCausalLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
@ -937,6 +939,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||
("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
|
||||
|
@ -81,6 +81,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mctct", "MCTCTProcessor"),
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
("mllama", "MllamaProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
|
@ -321,6 +321,7 @@ else:
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
27
src/transformers/models/moonshine/__init__.py
Normal file
27
src/transformers/models/moonshine/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_moonshine import *
|
||||
from .modeling_moonshine import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
224
src/transformers/models/moonshine/configuration_moonshine.py
Normal file
224
src/transformers/models/moonshine/configuration_moonshine.py
Normal file
@ -0,0 +1,224 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_moonshine.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class MoonshineConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Moonshine
|
||||
[UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32768):
|
||||
Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MoonshineModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 288):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 1152):
|
||||
Dimension of the MLP representations.
|
||||
encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
encoder_num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
encoder_num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
decoder_num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`decoder_num_attention_heads`.
|
||||
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder.
|
||||
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
decoder_start_token_id (`int`, *optional*, defaults to 1):
|
||||
Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
|
||||
are provided to the `generate` function. It is used to guide the model`s generation process depending on
|
||||
the task.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
partial_rotary_factor (`float`, *optional*, defaults to 0.9):
|
||||
Percentage of the query and keys which will have rotary embedding.
|
||||
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model is used as an encoder/decoder or not.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Denotes beginning of sequences token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
Denotes end of sequences token id.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MoonshineModel, MoonshineConfig
|
||||
|
||||
>>> # Initializing a Moonshine style configuration
|
||||
>>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = MoonshineModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "moonshine"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"num_key_value_heads": "encoder_num_key_value_heads",
|
||||
"num_attention_heads": "encoder_num_attention_heads",
|
||||
"num_hidden_layers": "encoder_num_hidden_layers",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32768,
|
||||
hidden_size=288,
|
||||
intermediate_size=1152,
|
||||
encoder_num_hidden_layers=6,
|
||||
decoder_num_hidden_layers=6,
|
||||
encoder_num_attention_heads=8,
|
||||
decoder_num_attention_heads=8,
|
||||
encoder_num_key_value_heads=None,
|
||||
decoder_num_key_value_heads=None,
|
||||
encoder_hidden_act="gelu",
|
||||
decoder_hidden_act="silu",
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
decoder_start_token_id=1,
|
||||
use_cache=True,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
partial_rotary_factor=0.9,
|
||||
is_encoder_decoder=True,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.encoder_num_hidden_layers = encoder_num_hidden_layers
|
||||
self.decoder_num_hidden_layers = decoder_num_hidden_layers
|
||||
self.encoder_num_attention_heads = encoder_num_attention_heads
|
||||
self.decoder_num_attention_heads = decoder_num_attention_heads
|
||||
|
||||
if encoder_num_key_value_heads is None:
|
||||
encoder_num_key_value_heads = encoder_num_attention_heads
|
||||
self.encoder_num_key_value_heads = encoder_num_key_value_heads
|
||||
|
||||
if decoder_num_key_value_heads is None:
|
||||
decoder_num_key_value_heads = decoder_num_attention_heads
|
||||
self.decoder_num_key_value_heads = decoder_num_key_value_heads
|
||||
|
||||
self.encoder_hidden_act = encoder_hidden_act
|
||||
self.decoder_hidden_act = decoder_hidden_act
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self.is_encoder_decoder = is_encoder_decoder
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MoonshineConfig"]
|
169
src/transformers/models/moonshine/convert_usefulsensors_to_hf.py
Normal file
169
src/transformers/models/moonshine/convert_usefulsensors_to_hf.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright 2025 Useful Sensors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers.models.moonshine.modeling_moonshine import MoonshineConfig, MoonshineForConditionalGeneration
|
||||
|
||||
|
||||
# Copied from https://github.com/usefulsensors/moonshine/blob/a1d77cc573b0471ac4602b86f67b3f48d67df1a9/moonshine/model.py
|
||||
def _get_weights(model_name):
|
||||
repo = "UsefulSensors/moonshine"
|
||||
|
||||
return (
|
||||
hf_hub_download(repo, f"{x}.weights.h5", subfolder=model_name) for x in ("preprocessor", "encoder", "decoder")
|
||||
)
|
||||
|
||||
|
||||
def _read_h5_weights(group, current_key="", weights={}):
|
||||
for key in group.keys():
|
||||
full_key = f"{current_key}.{key}" if current_key else key
|
||||
if isinstance(group[key], h5py.Dataset):
|
||||
w = np.array(group[key])
|
||||
w = torch.from_numpy(w)
|
||||
if len(w.shape) > 1:
|
||||
if len(w.shape) == 3:
|
||||
hidden_size = max(list(w.shape))
|
||||
try:
|
||||
w = w.reshape(hidden_size, hidden_size)
|
||||
except RuntimeError:
|
||||
# meaning its a conv layers
|
||||
pass
|
||||
w = w.transpose(0, -1)
|
||||
weights[full_key] = w
|
||||
else:
|
||||
_read_h5_weights(group[key], full_key, weights)
|
||||
return weights
|
||||
|
||||
|
||||
def _convert_layer_names(name, gated_mlp=False):
|
||||
name = re.sub(
|
||||
r"layers\.functional(?:_(\d+))?\.layers",
|
||||
lambda m: f'layers.{m.group(1) if m.group(1) else "0"}',
|
||||
name,
|
||||
count=1,
|
||||
)
|
||||
if gated_mlp:
|
||||
name = re.sub(r"functional\.layers\.dense\.", "mlp.fc1.", name)
|
||||
name = re.sub(r"functional\.layers\.dense_1\.", "mlp.fc2.", name)
|
||||
else:
|
||||
name = re.sub(r"functional\.layers\.sequential\.layers\.dense\.", "mlp.fc1.", name)
|
||||
name = re.sub(r"functional\.layers\.sequential\.layers\.dense_1\.", "mlp.fc2.", name)
|
||||
name = re.sub(r"layers\.sequential\.layers\.conv1d\.", "conv1.", name)
|
||||
name = re.sub(r"layers\.sequential\.layers\.conv1d_1\.", "conv2.", name)
|
||||
name = re.sub(r"layers\.sequential\.layers\.conv1d_2\.", "conv3.", name)
|
||||
name = re.sub(r"layers\.sequential\.layers\.group_normalization\.", "groupnorm.", name)
|
||||
name = re.sub(r"mha_with_rope\.key_dense", "self_attn.k_proj", name)
|
||||
name = re.sub(r"mha_with_rope\.query_dense", "self_attn.q_proj", name)
|
||||
name = re.sub(r"mha_with_rope\.value_dense", "self_attn.v_proj", name)
|
||||
name = re.sub(r"mha_with_rope\.output_dense", "self_attn.o_proj", name)
|
||||
name = re.sub(r"mha_precomputed_kv\.key_dense", "encoder_attn.k_proj", name)
|
||||
name = re.sub(r"mha_precomputed_kv\.query_dense", "encoder_attn.q_proj", name)
|
||||
name = re.sub(r"mha_precomputed_kv\.value_dense", "encoder_attn.v_proj", name)
|
||||
name = re.sub(r"mha_precomputed_kv\.output_dense", "encoder_attn.o_proj", name)
|
||||
name = re.sub(r"mha_causal_with_rope\.key_dense", "self_attn.k_proj", name)
|
||||
name = re.sub(r"mha_causal_with_rope\.query_dense", "self_attn.q_proj", name)
|
||||
name = re.sub(r"mha_causal_with_rope\.value_dense", "self_attn.v_proj", name)
|
||||
name = re.sub(r"mha_causal_with_rope\.output_dense", "self_attn.o_proj", name)
|
||||
name = re.sub(r"layer_normalization\.", "input_layernorm.", name)
|
||||
name = re.sub(r"layer_normalization_1\.", "post_attention_layernorm.", name)
|
||||
name = re.sub(r"layer_normalization_2\.", "final_layernorm.", name)
|
||||
name = re.sub(r"vars\.0", "weight", name)
|
||||
name = re.sub(r"vars\.1", "bias", name)
|
||||
name = re.sub(r"layers\.reversible_embedding", "embed_tokens", name)
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def _convert_weights(weights, encoder=True):
|
||||
if "layers.rotary_embedding.vars.0" in weights:
|
||||
weights.pop("layers.rotary_embedding.vars.0")
|
||||
|
||||
converted_weights = {}
|
||||
if encoder:
|
||||
converted_weights["layer_norm.weight"] = weights.pop("layers.layer_normalization.vars.0")
|
||||
else:
|
||||
converted_weights["norm.weight"] = weights.pop("layers.layer_normalization.vars.0")
|
||||
|
||||
for name, w in weights.items():
|
||||
if encoder:
|
||||
new_name = _convert_layer_names(name)
|
||||
else:
|
||||
new_name = _convert_layer_names(name, gated_mlp=True)
|
||||
converted_weights[new_name] = w
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def convert_usefulsensors_moonshine_to_hf(model_name, pytorch_dump_folder_path):
|
||||
preprocessor_weights_path, encoder_weights_path, decoder_weights_path = _get_weights(model_name)
|
||||
|
||||
with h5py.File(preprocessor_weights_path, "r") as f:
|
||||
loaded_preprocessor_weights = _read_h5_weights(f, weights={})
|
||||
|
||||
with h5py.File(encoder_weights_path, "r") as f:
|
||||
loaded_encoder_weights = _read_h5_weights(f, weights={})
|
||||
|
||||
with h5py.File(decoder_weights_path, "r") as f:
|
||||
loaded_decoder_weights = _read_h5_weights(f, weights={})
|
||||
|
||||
encoder_state_dict = {**loaded_encoder_weights, **loaded_preprocessor_weights}
|
||||
converted_encoder_state_dict = _convert_weights(encoder_state_dict)
|
||||
|
||||
converted_decoder_state_dict = _convert_weights(loaded_decoder_weights, encoder=False)
|
||||
converted_decoder_state_dict["embed_tokens.weight"] = converted_decoder_state_dict["embed_tokens.weight"].T
|
||||
|
||||
final_weights = {}
|
||||
for k, v in converted_encoder_state_dict.items():
|
||||
final_weights[f"model.encoder.{k}"] = v
|
||||
|
||||
for k, v in converted_decoder_state_dict.items():
|
||||
final_weights[f"model.decoder.{k}"] = v
|
||||
|
||||
if model_name == "tiny":
|
||||
config = MoonshineConfig()
|
||||
elif model_name == "base":
|
||||
config = MoonshineConfig(
|
||||
hidden_size=416,
|
||||
intermediate_size=1664,
|
||||
encoder_num_hidden_layers=8,
|
||||
decoder_num_hidden_layers=8,
|
||||
encoder_num_attention_heads=8,
|
||||
decoder_num_attention_heads=8,
|
||||
partial_rotary_factor=0.62,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
final_weights["proj_out.weight"] = converted_decoder_state_dict["embed_tokens.weight"]
|
||||
|
||||
model = MoonshineForConditionalGeneration(config)
|
||||
model.load_state_dict(final_weights)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument("--model_name", type=str, help="Path to the downloaded checkpoints")
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_usefulsensors_moonshine_to_hf(args.model_name, args.pytorch_dump_folder_path)
|
1570
src/transformers/models/moonshine/modeling_moonshine.py
Normal file
1570
src/transformers/models/moonshine/modeling_moonshine.py
Normal file
File diff suppressed because it is too large
Load Diff
1135
src/transformers/models/moonshine/modular_moonshine.py
Normal file
1135
src/transformers/models/moonshine/modular_moonshine.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -6523,6 +6523,27 @@ class ModernBertPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoonshineForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoonshineModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoonshinePreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoshiForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
0
tests/models/moonshine/__init__.py
Normal file
0
tests/models/moonshine/__init__.py
Normal file
620
tests/models/moonshine/test_modeling_moonshine.py
Normal file
620
tests/models/moonshine/test_modeling_moonshine.py
Normal file
@ -0,0 +1,620 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Moonshine model."""
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from transformers import MoonshineConfig, is_torch_available
|
||||
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
MoonshineForConditionalGeneration,
|
||||
MoonshineModel,
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
class MoonshineModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers
|
||||
seq_length=1000,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
vocab_size=147,
|
||||
hidden_size=8,
|
||||
intermediate_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
encoder_hidden_act="gelu",
|
||||
decoder_hidden_act="silu",
|
||||
decoder_start_token_id=85,
|
||||
bos_token_id=98,
|
||||
eos_token_id=98,
|
||||
pad_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.encoder_hidden_act = encoder_hidden_act
|
||||
self.decoder_hidden_act = decoder_hidden_act
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
|
||||
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
decoder_input_ids = torch.tensor(self.batch_size * [[self.decoder_start_token_id]], device=torch_device)
|
||||
decoder_attention_mask = decoder_input_ids.ne(self.pad_token_id)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_values, attention_mask, decoder_input_ids, decoder_attention_mask
|
||||
|
||||
def get_config(self):
|
||||
return MoonshineConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
encoder_num_hidden_layers=self.num_hidden_layers,
|
||||
decoder_num_hidden_layers=self.num_hidden_layers,
|
||||
encoder_num_attention_heads=self.num_attention_heads,
|
||||
decoder_num_attention_heads=self.num_attention_heads,
|
||||
encoder_num_key_value_heads=self.num_key_value_heads,
|
||||
decoder_num_key_value_heads=self.num_key_value_heads,
|
||||
encoder_hidden_act=self.encoder_hidden_act,
|
||||
decoder_hidden_act=self.decoder_hidden_act,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
model = MoonshineModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
model = MoonshineModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0.0
|
||||
|
||||
batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
|
||||
|
||||
for i in range(input_values.shape[0]):
|
||||
input_slice = input_values[i : i + 1, : input_lengths[i]]
|
||||
output = model(input_slice).last_hidden_state
|
||||
|
||||
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
|
||||
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
|
||||
|
||||
def check_output_attentions(self, config, input_values, attention_mask):
|
||||
model = MoonshineModel(config=config)
|
||||
model.config.layerdrop = 1.0
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
outputs = model(input_values, attention_mask=attention_mask, output_attentions=True)
|
||||
self.parent.assertTrue(len(outputs.attentions) > 0)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_values, attention_mask, decoder_input_ids, decoder_attention_mask = (
|
||||
self.prepare_config_and_inputs()
|
||||
)
|
||||
inputs_dict = {
|
||||
"input_values": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MoonshineModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MoonshineModel, MoonshineForConditionalGeneration) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"automatic-speech-recognition": MoonshineForConditionalGeneration,
|
||||
"feature-extraction": MoonshineModel,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MoonshineModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MoonshineConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", 1)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
|
||||
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 5
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
decoder_seq_length,
|
||||
subsampled_encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
|
||||
# Copied from tests.models.whisper.test_modeling_whisper.WhisperModelTest.test_hidden_states_output
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||
seq_length = self.model_tester.encoder_seq_length
|
||||
else:
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[subsampled_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
self.assertIsInstance(hidden_states, (list, tuple))
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[decoder_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# Copied from tests.models.whisper.test_modeling_whisper.WhisperModelTest.test_inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
decoder_input_ids = inputs.pop("decoder_input_ids", None)
|
||||
inputs.pop("decoder_attention_mask", None)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# Copied from tests.models.whisper.test_modeling_whisper.WhisperModelTest.test_resize_tokens_embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is False")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
if self.model_tester.is_training is False:
|
||||
model.eval()
|
||||
|
||||
model_vocab_size = config.vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||
|
||||
# make sure that decoder_input_ids are resized
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
models_equal = True
|
||||
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Copied from tests.models.whisper.test_modeling_whisper.WhisperModelTest.test_resize_embeddings_untied
|
||||
def test_resize_embeddings_untied(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is False")
|
||||
|
||||
original_config.tie_word_embeddings = False
|
||||
|
||||
# if model cannot untied embeddings -> leave test
|
||||
if original_config.tie_word_embeddings:
|
||||
self.skipTest(reason="Model cannot untie embeddings")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
# if no output embeddings -> leave test
|
||||
if model.get_output_embeddings() is None:
|
||||
continue
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_vocab_size = config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
@require_torch
|
||||
class MoonshineModelIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor_tiny = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
self.processor_base = AutoProcessor.from_pretrained("UsefulSensors/moonshine-base")
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
@slow
|
||||
def test_tiny_logits_single(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self.processor_tiny(self._load_datasamples(1), return_tensors="pt")
|
||||
inputs.to(torch_device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
-9.1107, 4.5538, 6.3902, -6.8141, -7.2459, -7.9077, -7.2842, -7.6045, -8.0387, -7.8354,
|
||||
-7.3870, -7.2453, -7.7423, -7.3914, -7.3869, -7.6982, -7.6422, -7.0507, -7.3982, -7.2486,
|
||||
-8.0799, -7.3303, -7.3675, -6.8769, -7.6879, -7.2684, -6.9868, -6.7459, -7.6858, -7.3052,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_base_logits_single(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base")
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self.processor_base(self._load_datasamples(1), return_tensors="pt")
|
||||
inputs.to(torch_device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
-6.7340, 1.9483, 5.2449, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
|
||||
-8.1070, -7.7696, -7.8809, -7.9451, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
|
||||
-7.9310, -8.1024, -7.8698, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9289,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_tiny_logits_batch(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self.processor_tiny(self._load_datasamples(4), return_tensors="pt", padding=True)
|
||||
inputs.to(torch_device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
[-8.0098, 5.0239, 4.5986, -6.8125, -7.1676, -7.8782, -7.2152, -7.5188, -7.9078, -7.7394],
|
||||
[-4.4394, -1.4429, 6.6715, -6.8927, -7.3748, -7.0967, -6.5255, -7.0255, -7.2583, -7.0007],
|
||||
[-10.0088, 3.2862, 0.7342, -6.5558, -6.8514, -6.5309, -6.4173, -6.9485, -6.6215, -6.6230],
|
||||
[-10.8083, 4.0034, -0.0635, -5.0501, -5.3903, -5.4587, -5.2416, -5.4742, -5.2662, -5.3154]
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_base_logits_batch(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base")
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self.processor_base(self._load_datasamples(4), return_tensors="pt", padding=True)
|
||||
inputs.to(torch_device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
[-7.7288, 1.4636, 5.2273, -7.7310, -7.6249, -7.6009, -7.6786, -7.6438, -7.8450, -7.7546],
|
||||
[-6.2161, -0.5891, 7.9489, -7.0693, -6.9996, -6.9980, -7.0952, -7.0830, -7.1685, -7.0136],
|
||||
[-7.3186, 3.1192, 3.8938, -5.7208, -5.8429, -5.7610, -5.9997, -5.8213, -5.8616, -5.8720],
|
||||
[-9.5488, 1.0147, 4.1174, -5.9972, -6.0616, -6.0331, -6.2105, -6.0320, -6.0791, -6.0875]
|
||||
])
|
||||
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_tiny_generation_single(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
model.to(torch_device)
|
||||
|
||||
audio_array = self._load_datasamples(1)
|
||||
inputs = self.processor_tiny(audio_array, return_tensors="pt")
|
||||
inputs.to(torch_device)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
transcript = self.processor_tiny.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_base_generation_single(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base")
|
||||
model.to(torch_device)
|
||||
|
||||
audio_array = self._load_datasamples(1)
|
||||
inputs = self.processor_base(audio_array, return_tensors="pt")
|
||||
inputs.to(torch_device)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
transcript = self.processor_base.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_generation_batch(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
|
||||
model.to(torch_device)
|
||||
|
||||
audio_array = self._load_datasamples(4)
|
||||
inputs = self.processor_tiny(audio_array, return_tensors="pt", padding=True)
|
||||
inputs.to(torch_device)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
transcript = self.processor_tiny.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
"Mr. Quilter is the apostle of the middle classes, and we are glad to welcome",
|
||||
"Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
"He tells us that at this festive season of the year, with Christmas and Rose beef lo",
|
||||
"He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_base_generation_batch(self):
|
||||
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base")
|
||||
model.to(torch_device)
|
||||
|
||||
audio_array = self._load_datasamples(4)
|
||||
inputs = self.processor_base(audio_array, return_tensors="pt", padding=True)
|
||||
inputs.to(torch_device)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
transcript = self.processor_base.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
"Mr. Quilter is the apostle of the middle classes, and we are glad to welcome",
|
||||
"Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
"He tells us that at this festive season of the year, with Christmas and rose beef lo",
|
||||
"He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
Loading…
Reference in New Issue
Block a user