Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Add FlaxWhisperForAudioClassification model (#23173)
* Add FlaxWhisperForAudioClassification model
* Add models to init
* Fix copies
* Fix automapping
* Fix failing test
This commit is contained in:
parent fc6c8b0eaa
commit 312b104ff6
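For quick orientation: the change ports Whisper's encoder plus an audio-classification head to Flax. A minimal usage sketch, mirroring the docstring example added further down in this diff (the checkpoint and dataset names come from that example):

```python
import jax.numpy as jnp
from datasets import load_dataset
from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification

# checkpoint and dataset taken from the docstring example in this diff
feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
model = FlaxWhisperForAudioClassification.from_pretrained(
    "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
)

sample = next(iter(load_dataset("google/fleurs", "all", split="validation", streaming=True)))
inputs = feature_extractor(
    sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
)

logits = model(inputs.input_features).logits
predicted_label = model.config.id2label[jnp.argmax(logits).item()]
```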
@@ -105,3 +105,9 @@ The original code can be found [here](https://github.com/openai/whisper).
 
 [[autodoc]] FlaxWhisperForConditionalGeneration
     - __call__
+
+## FlaxWhisperForAudioClassification
+
+[[autodoc]] FlaxWhisperForAudioClassification
+    - __call__
+
@@ -3779,6 +3779,7 @@ else:
             "FlaxWhisperForConditionalGeneration",
             "FlaxWhisperModel",
             "FlaxWhisperPreTrainedModel",
+            "FlaxWhisperForAudioClassification",
         ]
     )
     _import_structure["models.xglm"].extend(
@@ -6903,7 +6904,12 @@ if TYPE_CHECKING:
             FlaxWav2Vec2Model,
             FlaxWav2Vec2PreTrainedModel,
         )
-        from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
+        from .models.whisper import (
+            FlaxWhisperForAudioClassification,
+            FlaxWhisperForConditionalGeneration,
+            FlaxWhisperModel,
+            FlaxWhisperPreTrainedModel,
+        )
         from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
         from .models.xlm_roberta import (
             FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -229,6 +229,11 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("whisper", "FlaxWhisperForAudioClassification"),
+    ]
+)
 
 FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
 FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
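Context on the mapping above: the `*_MAPPING_NAMES` OrderedDicts hold class names as strings, and `_LazyAutoMapping` defers the actual import until a model type is looked up, so importing the auto module stays cheap. A rough, simplified sketch of that lazy lookup (an illustration, not the actual transformers implementation):

```python
import importlib
from collections import OrderedDict


# Simplified stand-in for transformers' _LazyAutoMapping: class names stay
# strings until a model type is actually looked up.
class LazyMapping:
    def __init__(self, names: OrderedDict, module: str):
        self._names = names  # e.g. {"whisper": "FlaxWhisperForAudioClassification"}
        self._module = module

    def __getitem__(self, model_type: str):
        class_name = self._names[model_type]
        # the import is deferred to first use
        return getattr(importlib.import_module(self._module), class_name)


mapping = LazyMapping(
    OrderedDict([("whisper", "FlaxWhisperForAudioClassification")]), "transformers"
)
# mapping["whisper"] would import transformers and return the new class
```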
@@ -75,6 +75,7 @@ else:
         "FlaxWhisperForConditionalGeneration",
         "FlaxWhisperModel",
         "FlaxWhisperPreTrainedModel",
+        "FlaxWhisperForAudioClassification",
     ]
 
 
@@ -126,6 +127,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .modeling_flax_whisper import (
+            FlaxWhisperForAudioClassification,
            FlaxWhisperForConditionalGeneration,
            FlaxWhisperModel,
            FlaxWhisperPreTrainedModel,
@@ -36,6 +36,7 @@ from ...modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
     FlaxSeq2SeqModelOutput,
+    FlaxSequenceClassifierOutput,
 )
 from ...modeling_flax_utils import (
     ACT2FN,
@@ -1506,3 +1507,163 @@ overwrite_call_docstring(
 append_replace_return_docstrings(
     FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
 )
+
+
+class FlaxWhisperForAudioClassificationModule(nn.Module):
+    config: WhisperConfig
+    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
+
+    def setup(self) -> None:
+        self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
+        self.config.is_encoder_decoder = False
+        num_layers = self.config.num_hidden_layers + 1
+        if self.config.use_weighted_layer_sum:
+            self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
+        self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
+        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+    def __call__(
+        self,
+        input_features,
+        encoder_outputs=None,
+        output_attentions=None,
+        output_hidden_states: bool = True,
+        return_dict: bool = True,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_features,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = jnp.stack(encoder_outputs, axis=1)
+            norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
+            hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
+        else:
+            hidden_states = encoder_outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        pooled_output = jnp.mean(hidden_states, axis=1)
+
+        logits = self.classifier(pooled_output)
+
+        if not return_dict:
+            return (logits,) + encoder_outputs[1:]
+
+        return FlaxSequenceClassifierOutput(
+            logits=logits,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
+class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
+    module_class = FlaxWhisperForAudioClassificationModule
+    dtype: jnp.dtype = jnp.float32
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+        # init input tensors
+        input_features = jnp.zeros(input_shape, dtype="f4")
+        input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        random_params = self.module.init(
+            rngs,
+            input_features=input_features,
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
+    @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
+    def __call__(
+        self,
+        input_features: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+        **kwargs,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_features=jnp.array(input_features, dtype="f4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            rngs=rngs,
+        )
+
+
+FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
+    Returns:
+
+    Transcription example:
+
+    ```python
+    >>> import jax.numpy as jnp
+    >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
+    >>> from datasets import load_dataset
+
+    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
+    >>> model = FlaxWhisperForAudioClassification.from_pretrained(
+    ...     "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
+    ... )
+    >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
+
+    >>> sample = next(iter(ds))
+
+    >>> inputs = feature_extractor(
+    ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
+    ... )
+    >>> input_features = inputs.input_features
+
+    >>> logits = model(input_features).logits
+
+    >>> predicted_class_ids = jnp.argmax(logits).item()
+    >>> predicted_label = model.config.id2label[predicted_class_ids]
+    >>> predicted_label
+    'af_za'
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
+)
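The head added above is deliberately small: optionally a softmax-weighted sum over all encoder layer outputs (`use_weighted_layer_sum`), then a projection, mean pooling over time, and a linear classifier. A standalone sketch of that pooling arithmetic; the shapes here are illustrative, not from the diff:

```python
import jax
import jax.numpy as jnp

batch, num_layers_plus_embed, time, hidden = 2, 3, 10, 16
# stand-ins for the stacked encoder hidden states and the learned layer weights
layer_outputs = jnp.ones((batch, num_layers_plus_embed, time, hidden))
layer_weights = jnp.repeat(1 / num_layers_plus_embed, num_layers_plus_embed)

# softmax-normalised weighted sum over the layer axis, as in the module above
norm_weights = jax.nn.softmax(layer_weights, axis=-1)
hidden_states = jnp.sum(layer_outputs * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)

# mean-pool over the time axis before the classifier
pooled = jnp.mean(hidden_states, axis=1)
assert pooled.shape == (batch, hidden)
```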
@@ -1182,6 +1182,13 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["flax"])
 
 
+class FlaxWhisperForAudioClassification(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
     _backends = ["flax"]
 
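The dummy registration above keeps `from transformers import FlaxWhisperForAudioClassification` importable when Flax is absent, and only errors when the class is actually used. A minimal sketch of the pattern; the real `DummyObject` and `requires_backends` live in `transformers.utils` and differ in detail:

```python
# Minimal sketch of the dummy-object pattern, assuming Flax is not installed.
class DummyObject(type):
    def __getattr__(cls, key):
        # only fires for attributes the dummy class does not define,
        # e.g. FlaxWhisperForAudioClassification.from_pretrained(...)
        if key.startswith("_"):
            raise AttributeError(key)
        raise ImportError(f"{cls.__name__} requires the flax backend to be installed.")


class FlaxWhisperForAudioClassification(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        raise ImportError("FlaxWhisperForAudioClassification requires flax to be installed.")
```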
@@ -12,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import functools
 import inspect
 import tempfile
@@ -41,6 +39,7 @@ if is_flax_available():
 
     from transformers import (
         FLAX_MODEL_MAPPING,
+        FlaxWhisperForAudioClassification,
         FlaxWhisperForConditionalGeneration,
         FlaxWhisperModel,
         WhisperFeatureExtractor,
@@ -704,3 +703,205 @@ class FlaxWhisperModelIntegrationTest(unittest.TestCase):
 
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
+
+class FlaxWhisperEncoderModelTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=60,
+        is_training=True,
+        use_labels=True,
+        hidden_size=16,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        input_channels=1,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=20,
+        max_source_positions=30,
+        num_mel_bins=80,
+        num_conv_layers=1,
+        suppress_tokens=None,
+        begin_suppress_tokens=None,
+        classifier_proj_size=4,
+        num_labels=2,
+        is_encoder_decoder=False,
+        is_decoder=False,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_labels = use_labels
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.input_channels = input_channels
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.num_mel_bins = num_mel_bins
+        self.max_position_embeddings = max_position_embeddings
+        self.max_source_positions = max_source_positions
+        self.num_conv_layers = num_conv_layers
+        self.suppress_tokens = suppress_tokens
+        self.begin_suppress_tokens = begin_suppress_tokens
+        self.classifier_proj_size = classifier_proj_size
+        self.num_labels = num_labels
+        self.is_encoder_decoder = is_encoder_decoder
+        self.is_decoder = is_decoder
+
+    def get_config(self):
+        return WhisperConfig(
+            d_model=self.hidden_size,
+            encoder_layers=self.num_hidden_layers,
+            decoder_layers=self.num_hidden_layers,
+            encoder_attention_heads=self.num_attention_heads,
+            decoder_attention_heads=self.num_attention_heads,
+            input_channels=self.input_channels,
+            dropout=self.hidden_dropout_prob,
+            attention_dropout=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            max_source_positions=self.max_source_positions,
+            decoder_ffn_dim=self.hidden_size,
+            encoder_ffn_dim=self.hidden_size,
+            suppress_tokens=self.suppress_tokens,
+            begin_suppress_tokens=self.begin_suppress_tokens,
+            classifier_proj_size=self.classifier_proj_size,
+            num_labels=self.num_labels,
+            is_encoder_decoder=self.is_encoder_decoder,
+            is_decoder=self.is_decoder,
+        )
+
+    def prepare_whisper_encoder_inputs_dict(
+        self,
+        input_features,
+    ):
+        return {
+            "input_features": input_features,
+        }
+
+    def prepare_config_and_inputs(self):
+        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
+
+        config = self.get_config()
+        inputs_dict = self.prepare_whisper_encoder_inputs_dict(
+            input_features=input_features,
+        )
+        return config, inputs_dict
+
+    def prepare_config_and_inputs_for_common(self):
+        config, inputs_dict = self.prepare_config_and_inputs()
+        return config, inputs_dict
+
+    def get_subsampled_output_lengths(self, input_lengths):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        for i in range(self.num_conv_layers):
+            input_lengths = (input_lengths - 1) // 2 + 1
+
+        return input_lengths
+
+    @property
+    def encoder_seq_length(self):
+        return self.get_subsampled_output_lengths(self.seq_length)
+
+
+@require_flax
+class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
+    all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else ()
+    is_encoder_decoder = False
+    fx_compatible = False
+    test_pruning = False
+    test_missing_keys = False
+
+    input_name = "input_features"
+
+    def setUp(self):
+        self.model_tester = FlaxWhisperEncoderModelTester(self)
+        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
+
+        self.all_model_classes = (
+            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
+        )
+        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    # overwrite because of `input_features`
+    def test_jit_compilation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+
+                @jax.jit
+                def model_jitted(input_features, **kwargs):
+                    return model(input_features=input_features, **kwargs)
+
+                with self.subTest("JIT Enabled"):
+                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                with self.subTest("JIT Disabled"):
+                    with jax.disable_jit():
+                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                self.assertEqual(len(outputs), len(jitted_outputs))
+                for jitted_output, output in zip(jitted_outputs, outputs):
+                    self.assertEqual(jitted_output.shape, output.shape)
+
+    # overwrite because of `input_features`
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.__call__)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            expected_arg_names = ["input_features", "attention_mask", "output_attentions"]
+            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+
+    def test_inputs_embeds(self):
+        pass
+
+    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
+    def test_model_common_attributes(self):
+        pass
+
+    # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
+    def test_resize_tokens_embeddings(self):
+        pass
+
+    # WhisperEncoder does not have any base model
+    def test_save_load_to_base(self):
+        pass
+
+    # WhisperEncoder does not have any base model
+    def test_save_load_from_base(self):
+        pass
+
+    # WhisperEncoder does not have any base model
+    @is_pt_flax_cross_test
+    def test_save_load_from_base_pt(self):
+        pass
+
+    # WhisperEncoder does not have any base model
+    @is_pt_flax_cross_test
+    def test_save_load_to_base_pt(self):
+        pass
+
+    # WhisperEncoder does not have any base model
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        pass
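The size bumps in the next hunks line up with Whisper's conv front end: each strided conv layer maps an input length L to (L - 1) // 2 + 1, so 3000 mel frames give 1500 encoder positions and 1500 give 750, which is why `seq_length` and `max_source_positions` move in lockstep. A quick check of the tester's formula:

```python
def subsampled_length(input_length: int, num_conv_layers: int = 1) -> int:
    # mirrors get_subsampled_output_lengths in the tester above
    for _ in range(num_conv_layers):
        input_length = (input_length - 1) // 2 + 1
    return input_length


assert subsampled_length(3000) == 1500  # encoder tester: seq_length -> max_source_positions
assert subsampled_length(1500) == 750   # model tester: seq_length -> max_source_positions
```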
@@ -95,7 +95,7 @@ class WhisperModelTester:
         self,
         parent,
         batch_size=13,
-        seq_length=60,
+        seq_length=1500,
         is_training=True,
         use_labels=False,
         vocab_size=200,
@@ -107,7 +107,7 @@ class WhisperModelTester:
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=30,
+        max_source_positions=750,
         max_target_positions=40,
         bos_token_id=98,
         eos_token_id=98,
@@ -1434,7 +1434,7 @@ class WhisperEncoderModelTester:
         self,
         parent,
         batch_size=13,
-        seq_length=60,
+        seq_length=3000,
         is_training=True,
         use_labels=True,
         hidden_size=16,
@@ -1445,7 +1445,7 @@ class WhisperEncoderModelTester:
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=30,
+        max_source_positions=1500,
         num_mel_bins=80,
         num_conv_layers=1,
         suppress_tokens=None,