Revert "Add FlaxWhisperForAudioClassification model" (#23154)

Revert "Add FlaxWhisperForAudioClassification model (#22883)"

This reverts commit c8f2c5c56e.
This commit is contained in:
Sylvain Gugger 2023-05-04 13:47:07 -04:00 committed by GitHub
parent b369e507aa
commit 01734dba84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 3 additions and 390 deletions

View File

@ -105,9 +105,3 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__
## FlaxWhisperForAudioClassification
[[autodoc]] FlaxWhisperForAudioClassification
- __call__

View File

@ -3779,7 +3779,6 @@ else:
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]
)
_import_structure["models.xglm"].extend(
@ -6904,12 +6903,7 @@ if TYPE_CHECKING:
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -229,11 +229,6 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
]
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("whisper", "FlaxWhisperForAudioClassification"),
]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)

View File

@ -75,7 +75,6 @@ else:
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]
@ -127,7 +126,6 @@ if TYPE_CHECKING:
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,

View File

@ -36,7 +36,6 @@ from ...modeling_flax_outputs import (
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
@ -1507,162 +1506,3 @@ overwrite_call_docstring(
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxWhisperForAudioClassificationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum:
self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states: bool = True,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = jnp.stack(encoder_outputs, axis=1)
norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = encoder_outputs[0]
hidden_states = self.projector(hidden_states)
pooled_output = jnp.mean(hidden_states, axis=1)
logits = self.classifier(pooled_output)
if not return_dict:
return (logits,) + encoder_outputs[1:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForAudioClassificationModule
dtype: jnp.dtype = jnp.float32
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
input_features=input_features,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
)
FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
Returns:
Transcription example:
```python
>>> import jax.numpy as jnp
>>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = FlaxWhisperForAudioClassification.from_pretrained(
... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
... )
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))
>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
... )
>>> input_features = inputs.input_features
>>> logits = model(input_features).logits
>>> predicted_class_ids = jnp.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'af_za'
```
"""
overwrite_call_docstring(
FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)

View File

@ -1182,13 +1182,6 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxWhisperForAudioClassification(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import tempfile
@ -39,7 +41,6 @@ if is_flax_available():
from transformers import (
FLAX_MODEL_MAPPING,
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
WhisperFeatureExtractor,
@ -703,205 +704,3 @@ class FlaxWhisperModelIntegrationTest(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
class FlaxWhisperEncoderModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=60,
is_training=True,
use_labels=True,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
input_channels=1,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
max_source_positions=30,
num_mel_bins=80,
num_conv_layers=1,
suppress_tokens=None,
begin_suppress_tokens=None,
classifier_proj_size=4,
num_labels=2,
is_encoder_decoder=False,
is_decoder=False,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.input_channels = input_channels
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_mel_bins = num_mel_bins
self.max_position_embeddings = max_position_embeddings
self.max_source_positions = max_source_positions
self.num_conv_layers = num_conv_layers
self.suppress_tokens = suppress_tokens
self.begin_suppress_tokens = begin_suppress_tokens
self.classifier_proj_size = classifier_proj_size
self.num_labels = num_labels
self.is_encoder_decoder = is_encoder_decoder
self.is_decoder = is_decoder
def get_config(self):
return WhisperConfig(
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
input_channels=self.input_channels,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
max_source_positions=self.max_source_positions,
decoder_ffn_dim=self.hidden_size,
encoder_ffn_dim=self.hidden_size,
suppress_tokens=self.suppress_tokens,
begin_suppress_tokens=self.begin_suppress_tokens,
classifier_proj_size=self.classifier_proj_size,
num_labels=self.num_labels,
is_encoder_decoder=self.is_encoder_decoder,
is_decoder=self.is_decoder,
)
def prepare_whisper_encoder_inputs_dict(
self,
input_features,
):
return {
"input_features": input_features,
}
def prepare_config_and_inputs(self):
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
config = self.get_config()
inputs_dict = self.prepare_whisper_encoder_inputs_dict(
input_features=input_features,
)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def get_subsampled_output_lengths(self, input_lengths):
"""
Computes the output length of the convolutional layers
"""
for i in range(self.num_conv_layers):
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
@property
def encoder_seq_length(self):
return self.get_subsampled_output_lengths(self.seq_length)
@require_flax
class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else ()
is_encoder_decoder = False
fx_compatible = False
test_pruning = False
test_missing_keys = False
input_name = "input_features"
def setUp(self):
self.model_tester = FlaxWhisperEncoderModelTester(self)
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
self.all_model_classes = (
make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
def test_config(self):
self.config_tester.run_common_tests()
# overwrite because of `input_features`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_features, **kwargs):
return model(input_features=input_features, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite because of `input_features`
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_features", "attention_mask", "output_attentions"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_inputs_embeds(self):
pass
# WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
def test_model_common_attributes(self):
pass
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
# WhisperEncoder does not have any base model
def test_save_load_to_base(self):
pass
# WhisperEncoder does not have any base model
def test_save_load_from_base(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass